Skip to content
Snippets Groups Projects
Commit 13b82620 authored by Grießhaber Daniel's avatar Grießhaber Daniel :squid:
Browse files

refactor _call_model_cached function implementation

parent 7420ba00
No related branches found
No related tags found
2 merge requests!2remove is_chat argument,!1Refactor models
...@@ -44,11 +44,16 @@ class LLMModel(ABC): ...@@ -44,11 +44,16 @@ class LLMModel(ABC):
self.kwargs = kwargs self.kwargs = kwargs
# set up caching for model calls # set up caching for model calls
self._call_model_cached = None
if not options.disable_cache: if not options.disable_cache:
cache = Cache(Path(".cache_dir", self.model_cache_key)) cache = Cache(Path(".cache_dir", self.model_cache_key))
self._call_model_cached = cache.memoize(typed=True, ignore=[0, "func"])(
self._call_model_cached @cache.memoize(typed=True, ignore=["func"])
) def _call_function(func, *args, **kwargs):
# `cache_key` is added to the cache key (e.g., to distinguish between different models), but it is not used in the function
return func(*args, **kwargs)
self._call_model_cached = _call_function
@abstractmethod @abstractmethod
def build_model_input( def build_model_input(
...@@ -114,16 +119,11 @@ class LLMModel(ABC): ...@@ -114,16 +119,11 @@ class LLMModel(ABC):
warnings.warn("Caching is disabled when a grammar is provided.") warnings.warn("Caching is disabled when a grammar is provided.")
use_cache = False use_cache = False
if use_cache: if use_cache and self._call_model_cached is not None:
# use cached function call
return self._call_model_cached(model_completion_fn, **kwargs) return self._call_model_cached(model_completion_fn, **kwargs)
else: else:
return model_completion_fn(**kwargs) return model_completion_fn(**kwargs)
def _call_model_cached(self, func, *args, **kwargs):
# `cache_key` is added to the cache key (e.g., to distinguish between different models), but it is not used in the function
return func(*args, **kwargs)
@property @property
def model_cache_key(self): def model_cache_key(self):
unique_options_key = json.dumps( unique_options_key = json.dumps(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment