diff --git a/evoprompt/models.py b/evoprompt/models.py index 78dd7ed39e5e34a5b971eb666e1e289fce3fcd4f..3eb592503c0f799bf4048ea91706f7c4d321b910 100644 --- a/evoprompt/models.py +++ b/evoprompt/models.py @@ -44,11 +44,16 @@ class LLMModel(ABC): self.kwargs = kwargs # set up caching for model calls + self._call_model_cached = None if not options.disable_cache: cache = Cache(Path(".cache_dir", self.model_cache_key)) - self._call_model_cached = cache.memoize(typed=True, ignore=[0, "func"])( - self._call_model_cached - ) + + @cache.memoize(typed=True, ignore=["func"]) + def _call_function(func, *args, **kwargs): + # `cache_key` is added to the cache key (e.g., to distinguish between different models), but it is not used in the function + return func(*args, **kwargs) + + self._call_model_cached = _call_function @abstractmethod def build_model_input( @@ -114,16 +119,11 @@ class LLMModel(ABC): warnings.warn("Caching is disabled when a grammar is provided.") use_cache = False - if use_cache: - # use cached function call + if use_cache and self._call_model_cached is not None: return self._call_model_cached(model_completion_fn, **kwargs) else: return model_completion_fn(**kwargs) - def _call_model_cached(self, func, *args, **kwargs): - # `cache_key` is added to the cache key (e.g., to distinguish between different models), but it is not used in the function - return func(*args, **kwargs) - @property def model_cache_key(self): unique_options_key = json.dumps(