From b9311d44168662e01c61186a8b35cecac917b3a0 Mon Sep 17 00:00:00 2001
From: Daniel Grießhaber <griesshaber@hdm-stuttgart.de>
Date: Fri, 9 Aug 2024 06:38:15 +0200
Subject: [PATCH] add persistent caching to `_evaluate_sample`

---
 evoprompt/models.py    | 18 ++++++++++++++++--
 evoprompt/task/task.py | 11 ++++++++++-
 requirements.txt       |  3 ++-
 3 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 372cef7..e29b212 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -59,6 +59,11 @@ class LLMModel(ABC):
 
     def register_arguments(cls, parser: ArgumentParser):
         pass
 
+    @property
+    @abstractmethod
+    def cache_key(self):
+        pass
+
 
 class Llama(LLMModel):
@@ -77,7 +82,7 @@ class Llama(LLMModel):
         if options.llama_path is not None:
             # use local file
             self.model = llama_cpp.Llama(
-                options.llama_path,
+                str(options.llama_path),
                 chat_format=options.chat_format,
                 chat_handler=options.chat_handler,
                 verbose=options.verbose > 1 or options.llama_verbose,
@@ -159,7 +164,7 @@ class Llama(LLMModel):
         group = parser.add_argument_group(f"{cls.__name__} model arguments")
         group.add_argument(
             "--llama-path",
-            type=str,
+            type=Path,
             help="Specify path to local Llama model, takes precedence over --llama-model",
         ),
         group.add_argument(
@@ -192,6 +197,11 @@ class Llama(LLMModel):
             help="Increase verbosity of Llama model",
         )
 
+    @property
+    def cache_key(self):
+        model_name = Path(self.model.model_path).stem
+        return f"llama-{model_name}"
+
 
 class OpenAI(LLMModel):
     """Queries an OpenAI model using its API."""
@@ -268,6 +278,10 @@ class OpenAI(LLMModel):
         group = parser.add_argument_group("OpenAI model arguments")
         group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
 
+    @property
+    def cache_key(self):
+        return f"openai-{self.model_name}"
+
 
 argument_group = argument_parser.add_argument_group("Model arguments")
 argument_group.add_argument(
diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index fa87832..36fe905 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -1,10 +1,13 @@
 import logging
+import shelve
 from abc import ABCMeta, abstractmethod
 from collections import deque
 from dataclasses import KW_ONLY, dataclass
+from pathlib import Path
 from statistics import mean
 from typing import Iterable, Literal, Union
 
+import ring
 from datasets import Dataset, load_dataset
 from llama_cpp import LlamaGrammar
 from tqdm import tqdm
@@ -255,7 +258,7 @@ class Task(metaclass=ABCMeta):
 
     def __init__(
         self,
-        model: Union[LLMModel],
+        model: LLMModel,
         validation_dataset: str | None = None,
         test_dataset: str | None = None,
         *,
@@ -290,6 +293,12 @@ class Task(metaclass=ABCMeta):
         if self.debug and len(self.test_dataset) > 5:
             self.test_dataset = self.test_dataset.select(range(5))
 
+        # cache evaluation runs to disk in a specific shelf for each model
+        cache_path = Path(".cache_dir") / self.model.cache_key
+        cache_path.parent.mkdir(exist_ok=True, parents=True)
+        shelf = shelve.open(cache_path)
+        self._evaluate_sample = ring.shelve(shelf)(self._evaluate_sample)
+
     def load_validation_set(
         self, validation_dataset: str, validation_split: str | None
     ):
diff --git a/requirements.txt b/requirements.txt
index a02036c..849f05c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 numpy
 datasets>=2.20
 evaluate
+ring
 llama-cpp-python
 tqdm
 graphviz
@@ -9,4 +10,4 @@ openai
 py7zr
 rouge-score
 sacrebleu
-sacremoses
\ No newline at end of file
+sacremoses
--
GitLab
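
For context, the caching pattern the patch relies on can be illustrated with a small standalone sketch: `ring.shelve` wraps a callable so that each result is memoized in a persistent `shelve` store, keyed by the function and its arguments, and therefore survives process restarts — this is how the patched `Task.__init__` wraps `self._evaluate_sample`. The function `score_sample` and the shelf name `example-model` below are hypothetical stand-ins for `_evaluate_sample` and the per-model `cache_key`:

    import shelve
    from pathlib import Path

    import ring

    # ensure the cache directory exists, mirroring the mkdir call in the patch
    cache_path = Path(".cache_dir") / "example-model"
    cache_path.parent.mkdir(exist_ok=True, parents=True)

    # open a persistent shelf (str() for dbm backends that predate path-like support)
    shelf = shelve.open(str(cache_path))


    @ring.shelve(shelf)
    def score_sample(prompt: str, sample_id: int) -> int:
        # stand-in for a costly model evaluation; the body only runs on a cache miss
        print(f"evaluating sample {sample_id} with prompt {prompt!r}")
        return len(prompt) + sample_id


    score_sample("Classify the sentiment.", 1)  # computed and written to the shelf
    score_sample("Classify the sentiment.", 1)  # answered from the shelf, no print

    shelf.close()

Because the shelf path is derived from each model's `cache_key` (`llama-<model file stem>` or `openai-<model name>`), every model gets its own cache file, so cached scores produced by one model are never replayed for another.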