From a68e64155eb010ca446b55a89198a51ab28dd6c0 Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Fri, 16 Aug 2024 12:43:08 +0200
Subject: [PATCH] Add caching functionality for model calls

---
 evoprompt/models.py | 40 ++++++++++++++++++++++++++++++++++++++--
 1 file changed, 38 insertions(+), 2 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 53d1ce0..bc821d1 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -1,3 +1,4 @@
+import functools
 import inspect
 import logging
 from abc import ABC, abstractmethod
@@ -5,6 +6,7 @@ from argparse import ArgumentParser, Namespace
 from pathlib import Path
 from typing import Any, Callable, ClassVar
 
+import joblib
 import llama_cpp
 import openai
 
@@ -15,6 +17,15 @@ from evoprompt.utils import get_seed
 
 logger = logging.getLogger(__name__)
 
+mem = joblib.Memory(location=".cache_dir", verbose=0)
+
+
+@mem.cache(ignore=["func"])
+def cached_function_call(func, cache_key, *args, **kwargs):
+    # `cache_key` is added to the cache key (e.g., to distinguish between different models), but it is not used in the function
+    return func(*args, **kwargs)
+
+
 class LLMModel(ABC):
     models: ClassVar[dict[str, type["LLMModel"]]] = {}
     chat: bool
@@ -31,10 +42,31 @@ class LLMModel(ABC):
             raise ValueError("Model %s does not exist", name)
         return cls.models[name](options=options, **kwargs)
 
+    def _compute_cache_key(self, name):
+        # we use a tuple of the model name, the options, and the kwargs as the cache key
+        return (
+            (name,)
+            + tuple((key, value) for key, value in self.options.__dict__.items())
+            + tuple((key, value) for key, value in self.kwargs.items())
+        )
+        # NOTE this implementation would be better but it does not produce deterministic results, maybe with another cache library/backend?
+        # return functools._make_key(
+        #     (
+        #         (name,)
+        #         + tuple((key, value) for key, value in self.options.__dict__.items())
+        #     ),
+        #     self.kwargs,
+        #     typed=True,
+        # )
+
     def __init__(self, options: Namespace, **kwargs):
         self.usage = ModelUsage()
         self.chat = options.chat
 
+        # store options and kwargs for computing the cache key
+        self.options = options
+        self.kwargs = kwargs
+
     def create_completion(
         self,
         system_message: str | None,
@@ -113,8 +145,12 @@ class LLMModel(ABC):
             # logger.warnning
             use_cache = False
 
-        # TODO implement caching
-        return model_completion_fn(**kwargs)
+        if use_cache:
+            # use cached function call
+            cache_key = self._compute_cache_key(model_completion_fn.__name__)
+            return cached_function_call(model_completion_fn, cache_key, **kwargs)
+        else:
+            return model_completion_fn(**kwargs)
 
     @classmethod
     @abstractmethod
--
GitLab
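
The patch routes model completions through a disk cache built on joblib.Memory: the callable is excluded from the hash via ignore=["func"], and an explicit cache_key tuple stands in for it so that results from different models or option sets do not collide. The following is a minimal, self-contained sketch of that pattern, not the project's code; fake_completion and the hand-built key tuple are hypothetical stand-ins for the real model call and for _compute_cache_key's output.

    # Sketch of the caching pattern from the patch, assuming a hypothetical
    # fake_completion() in place of an actual LLM call.
    import joblib

    mem = joblib.Memory(location=".cache_dir", verbose=0)


    @mem.cache(ignore=["func"])
    def cached_function_call(func, cache_key, *args, **kwargs):
        # joblib hashes cache_key, *args, and **kwargs to locate a cache entry;
        # the callable itself is ignored, so unpicklable client objects never
        # enter the hash.
        return func(*args, **kwargs)


    def fake_completion(prompt: str, temperature: float = 0.0) -> str:
        # stand-in for an expensive model call
        return f"completion for {prompt!r} at temperature {temperature}"


    if __name__ == "__main__":
        # hand-built key mirroring the (name, options, kwargs) tuple idea
        key = ("fake-model", ("temperature", 0.0))
        # first call executes fake_completion; the identical second call is
        # served from .cache_dir without re-running it
        print(cached_function_call(fake_completion, key, "Hello", temperature=0.0))
        print(cached_function_call(fake_completion, key, "Hello", temperature=0.0))

Excluding the function from the hash and passing a separate key is what lets a single cached wrapper serve several backends (llama_cpp, openai) while still keeping their results apart.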