From b9311d44168662e01c61186a8b35cecac917b3a0 Mon Sep 17 00:00:00 2001
From: Daniel Grießhaber <griesshaber@hdm-stuttgart.de>
Date: Fri, 9 Aug 2024 06:38:15 +0200
Subject: [PATCH] add persistent caching to `_evaluate_sample`

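Results of `Task._evaluate_sample` are now memoized on disk, so
repeated evaluation runs do not re-query the model for samples it has
already seen. The cache is a `shelve` database wrapped with the `ring`
library, with a separate shelf per model selected through the new
abstract `LLMModel.cache_key` property ("llama-<model file stem>" for
local Llama models, "openai-<model name>" for the OpenAI API).

A minimal sketch of the memoization pattern used here (the file and
function names below are illustrative, not part of this patch):

    import shelve

    import ring

    shelf = shelve.open("example_cache")  # hypothetical cache file

    @ring.shelve(shelf)
    def evaluate(sample_id: str) -> float:
        # stand-in for the expensive model call; the return value is
        # persisted in the shelf and reused on subsequent calls
        return float(len(sample_id))

    evaluate("sample-1")  # computed and written to the shelf
    evaluate("sample-1")  # answered from the on-disk cache

Along the way, `--llama-path` is parsed as a `pathlib.Path` (and
converted back to `str` where `llama_cpp.Llama` and `shelve.open`
expect plain strings), and the redundant `Union[LLMModel]` annotation
is simplified to `LLMModel`.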
---
 evoprompt/models.py    | 18 ++++++++++++++++--
 evoprompt/task/task.py | 11 ++++++++++-
 requirements.txt       |  3 ++-
 3 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 372cef7..e29b212 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -59,6 +59,11 @@ class LLMModel(ABC):
     def register_arguments(cls, parser: ArgumentParser):
         pass
 
+    @property
+    @abstractmethod
+    def cache_key(self):
+        pass
+
 
 class Llama(LLMModel):
 
@@ -77,7 +82,7 @@ class Llama(LLMModel):
         if options.llama_path is not None:
             # use local file
             self.model = llama_cpp.Llama(
-                options.llama_path,
+                str(options.llama_path),
                 chat_format=options.chat_format,
                 chat_handler=options.chat_handler,
                 verbose=options.verbose > 1 or options.llama_verbose,
@@ -159,7 +164,7 @@ class Llama(LLMModel):
         group = parser.add_argument_group(f"{cls.__name__} model arguments")
         group.add_argument(
             "--llama-path",
-            type=str,
+            type=Path,
             help="Specify path to local Llama model, takes precedence over --llama-model",
         ),
         group.add_argument(
@@ -192,6 +197,11 @@ class Llama(LLMModel):
             help="Increase verbosity of Llama model",
         )
 
+    @property
+    def cache_key(self):
+        model_name = Path(self.model.model_path).stem
+        return f"llama-{model_name}"
+
 
 class OpenAI(LLMModel):
     """Queries an OpenAI model using its API."""
@@ -268,6 +278,10 @@ class OpenAI(LLMModel):
         group = parser.add_argument_group("OpenAI model arguments")
         group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
 
+    @property
+    def cache_key(self):
+        return f"openai-{self.model_name}"
+
 
 argument_group = argument_parser.add_argument_group("Model arguments")
 argument_group.add_argument(
diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index fa87832..36fe905 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -1,10 +1,13 @@
 import logging
+import shelve
 from abc import ABCMeta, abstractmethod
 from collections import deque
 from dataclasses import KW_ONLY, dataclass
+from pathlib import Path
 from statistics import mean
 from typing import Iterable, Literal, Union
 
+import ring
 from datasets import Dataset, load_dataset
 from llama_cpp import LlamaGrammar
 from tqdm import tqdm
@@ -255,7 +258,7 @@ class Task(metaclass=ABCMeta):
 
     def __init__(
         self,
-        model: Union[LLMModel],
+        model: LLMModel,
         validation_dataset: str | None = None,
         test_dataset: str | None = None,
         *,
@@ -290,6 +293,12 @@ class Task(metaclass=ABCMeta):
         if self.debug and len(self.test_dataset) > 5:
             self.test_dataset = self.test_dataset.select(range(5))
 
+        # cache evaluation runs to disk in a specific shelf for each model
+        cache_path = Path(".cache_dir") / self.model.cache_key
+        cache_path.parent.mkdir(exist_ok=True, parents=True)
+        shelf = shelve.open(str(cache_path))
+        self._evaluate_sample = ring.shelve(shelf)(self._evaluate_sample)
+
     def load_validation_set(
         self, validation_dataset: str, validation_split: str | None
     ):
diff --git a/requirements.txt b/requirements.txt
index a02036c..849f05c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 numpy
 datasets>=2.20
 evaluate
+ring
 llama-cpp-python
 tqdm
 graphviz
@@ -9,4 +10,4 @@ openai
 py7zr
 rouge-score
 sacrebleu
-sacremoses
\ No newline at end of file
+sacremoses
-- 
GitLab