diff --git a/evoprompt/models.py b/evoprompt/models.py
index c6468038d7476a6f4d7f2f66009304f2e5e769e5..87bab1067d86edb2e8805ad0ca8b55ba9a9f51b0 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -1,5 +1,6 @@
-import functools
+import hashlib
 import inspect
+import json
 import logging
 import warnings
 from abc import ABC, abstractmethod
@@ -35,11 +36,6 @@ class LLMModel(ABC):
             raise ValueError("Model %s does not exist", name)
         return cls.models[name](options=options, **kwargs)
 
-    @functools.lru_cache
-    def _compute_cache_key(self, name, **kwargs):
-        # we use a tuple of the model name, the options, and the kwargs as the cache key
-        return (name,) + tuple((key, value) for key, value in kwargs.items())
-
     def __init__(self, options: Namespace, **kwargs):
         self.usage = ModelUsage()
 
@@ -120,17 +116,28 @@ class LLMModel(ABC):
 
         if use_cache:
             # use cached function call
-            cache_key = self._compute_cache_key(
-                self.__class__.__name__, **self.options.__dict__, **self.kwargs
-            )
-            return self._call_model_cached(model_completion_fn, cache_key, **kwargs)
+            return self._call_model_cached(model_completion_fn, **kwargs)
         else:
            return model_completion_fn(**kwargs)
 
-    def _call_model_cached(self, func, cache_key, *args, **kwargs):
-        # `cache_key` is added to the cache key (e.g., to distinguish between different models), but it is not used in the function
+    def _call_model_cached(self, func, *args, **kwargs):
+        # `model_cache_key` contributes to the cache key (e.g., to distinguish between different models), but it is not passed to the function
         return func(*args, **kwargs)
 
+    @property
+    def model_cache_key(self):
+        unique_options_key = json.dumps(
+            vars(self.options),
+            sort_keys=True,
+        ) + json.dumps(
+            self.kwargs,
+            sort_keys=True,
+        )
+        cache_key = (
+            str(self.model_name) + hashlib.sha1(unique_options_key.encode()).hexdigest()
+        )
+        return cache_key
+
     @classmethod
     @abstractmethod
     def register_arguments(cls, parser: ArgumentParser):
@@ -196,7 +203,6 @@ class Llama(LLMModel):
         messages: list[dict[str, str]],
         history: list[dict[str, str]] | None = None,
     ):
-
         if system_message is not None:
             prompt = system_message + prompt
         return {"prompt": prompt}
@@ -216,10 +222,6 @@ class Llama(LLMModel):
         usage = ModelUsage(**response["usage"])
         return response_text, usage
 
-    @property
-    def model_cache_key(self):
-        return self.model_name
-
     @classmethod
     def register_arguments(cls, parser: ArgumentParser):
         group = parser.add_argument_group(f"{cls.__name__} model arguments")
@@ -326,10 +328,6 @@ class OpenAI(LLMModel):
         usage = ModelUsage(**response.usage.__dict__)
         return response_text, usage
 
-    @property
-    def model_cache_key(self):
-        return self.model_name
-
     @classmethod
     def register_arguments(cls, parser: ArgumentParser):
         group = parser.add_argument_group("OpenAI model arguments")
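
For reviewers, here is a minimal, self-contained sketch of the keying behavior this patch introduces. `StubModel` and its constructor are hypothetical stand-ins (the real class hierarchy, `ModelUsage`, and the cache wiring are omitted); only the `model_cache_key` body mirrors the added code.

```python
import hashlib
import json
from argparse import Namespace


class StubModel:
    """Hypothetical stand-in for LLMModel; only model_cache_key mirrors the patch."""

    def __init__(self, model_name: str, options: Namespace, **kwargs):
        self.model_name = model_name
        self.options = options
        self.kwargs = kwargs

    @property
    def model_cache_key(self):
        # Serialize options and kwargs deterministically (sort_keys=True), then
        # hash them so the key stays short regardless of how many options exist.
        unique_options_key = json.dumps(
            vars(self.options),
            sort_keys=True,
        ) + json.dumps(
            self.kwargs,
            sort_keys=True,
        )
        return str(self.model_name) + hashlib.sha1(unique_options_key.encode()).hexdigest()


a = StubModel("llama-2-7b", Namespace(temperature=0.0, max_tokens=32))
b = StubModel("llama-2-7b", Namespace(max_tokens=32, temperature=0.0))
c = StubModel("llama-2-7b", Namespace(temperature=0.7, max_tokens=32))

assert a.model_cache_key == b.model_cache_key  # option order does not matter
assert a.model_cache_key != c.model_cache_key  # changed options yield a new key
```

Hoisting the property into the base class removes the duplicated `model_cache_key` overrides in `Llama` and `OpenAI`, and the key now reflects the full option set rather than the model name alone, so two instances of the same model with different options no longer share cache entries.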