From a68e64155eb010ca446b55a89198a51ab28dd6c0 Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Fri, 16 Aug 2024 12:43:08 +0200
Subject: [PATCH] Add caching functionality for model calls

---
 evoprompt/models.py | 40 ++++++++++++++++++++++++++++++++++++++--
 1 file changed, 38 insertions(+), 2 deletions(-)
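
For illustration only, a minimal sketch of how the cached call introduced by this patch behaves, assuming joblib is installed and ".cache_dir" is writable; `expensive_call` is a hypothetical stand-in for a model completion function:

    import joblib

    mem = joblib.Memory(location=".cache_dir", verbose=0)

    @mem.cache(ignore=["func"])
    def cached_function_call(func, cache_key, *args, **kwargs):
        # `func` is excluded from hashing; only `cache_key`, `args`, and
        # `kwargs` decide whether a stored result is reused
        return func(*args, **kwargs)

    def expensive_call(prompt):
        print("cache miss, computing")  # printed only when the cache misses
        return prompt.upper()

    cached_function_call(expensive_call, ("model-a",), "hello")  # computes
    cached_function_call(expensive_call, ("model-a",), "hello")  # served from cache
    cached_function_call(expensive_call, ("model-b",), "hello")  # different key, computes again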

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 53d1ce0..bc821d1 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -1,3 +1,4 @@
+import functools
 import inspect
 import logging
 from abc import ABC, abstractmethod
@@ -5,6 +6,7 @@ from argparse import ArgumentParser, Namespace
 from pathlib import Path
 from typing import Any, Callable, ClassVar
 
+import joblib
 import llama_cpp
 import openai
 
@@ -15,6 +17,15 @@ from evoprompt.utils import get_seed
 logger = logging.getLogger(__name__)
 
 
+mem = joblib.Memory(location=".cache_dir", verbose=0)
+
+
+@mem.cache(ignore=["func"])
+def cached_function_call(func, cache_key, *args, **kwargs):
+    # `cache_key` only contributes to the cache key (e.g., to distinguish between different models); it is not used inside the function
+    return func(*args, **kwargs)
+
+
 class LLMModel(ABC):
     models: ClassVar[dict[str, type["LLMModel"]]] = {}
     chat: bool
@@ -31,10 +42,31 @@ class LLMModel(ABC):
             raise ValueError("Model %s does not exist", name)
         return cls.models[name](options=options, **kwargs)
 
+    def _compute_cache_key(self, name):
+        # we use a tuple of the model name, the options, and the kwargs as the cache key
+        return (
+            (name,)
+            + tuple((key, value) for key, value in self.options.__dict__.items())
+            + tuple((key, value) for key, value in self.kwargs.items())
+        )
+        # NOTE the implementation below would be preferable, but it does not produce deterministic results across runs; it might work with a different cache library/backend
+        # return functools._make_key(
+        #     (
+        #         (name,)
+        #         + tuple((key, value) for key, value in self.options.__dict__.items())
+        #     ),
+        #     self.kwargs,
+        #     typed=True,
+        # )
+
     def __init__(self, options: Namespace, **kwargs):
         self.usage = ModelUsage()
         self.chat = options.chat
 
+        # store options and kwargs so they can be used to compute the cache key
+        self.options = options
+        self.kwargs = kwargs
+
     def create_completion(
         self,
         system_message: str | None,
@@ -113,8 +145,12 @@ class LLMModel(ABC):
             # logger.warnning
             use_cache = False
 
-        # TODO implement caching
-        return model_completion_fn(**kwargs)
+        if use_cache:
+            # use cached function call
+            cache_key = self._compute_cache_key(model_completion_fn.__name__)
+            return cached_function_call(model_completion_fn, cache_key, **kwargs)
+        else:
+            return model_completion_fn(**kwargs)
 
     @classmethod
     @abstractmethod
-- 
GitLab