Skip to content
Snippets Groups Projects
Commit a68e6415 authored by Max Kimmich's avatar Max Kimmich
Browse files

Add caching functionality for model calls

parent d685a040
No related branches found
No related tags found
1 merge request!1Refactor models
import functools
import inspect
import logging
from abc import ABC, abstractmethod
......@@ -5,6 +6,7 @@ from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Any, Callable, ClassVar
import joblib
import llama_cpp
import openai
......@@ -15,6 +17,15 @@ from evoprompt.utils import get_seed
logger = logging.getLogger(__name__)

# Disk-backed memoization shared by all model calls; entries persist across
# runs in `.cache_dir` (verbose=0 silences joblib's logging).
mem = joblib.Memory(location=".cache_dir", verbose=0)
@mem.cache(ignore=["func"])
def cached_function_call(func, cache_key, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` with the result cached on disk by joblib.

    ``func`` itself is excluded from the cache key (see ``ignore``), so
    ``cache_key`` plus the remaining arguments must uniquely identify the
    call — callers pass a per-model key for this purpose.
    """
    # `cache_key` is added to the cache key (e.g., to distinguish between different models), but it is not used in the function
    return func(*args, **kwargs)
class LLMModel(ABC):
models: ClassVar[dict[str, type["LLMModel"]]] = {}
chat: bool
......@@ -31,10 +42,31 @@ class LLMModel(ABC):
raise ValueError("Model %s does not exist", name)
return cls.models[name](options=options, **kwargs)
def _compute_cache_key(self, name):
# we use a tuple of the model name, the options, and the kwargs as the cache key
return (
(name,)
+ tuple((key, value) for key, value in self.options.__dict__.items())
+ tuple((key, value) for key, value in self.kwargs.items())
)
# NOTE this implementation would be better but it does not produce deterministic results, maybe with another cache library/backend?
# return functools._make_key(
# (
# (name,)
# + tuple((key, value) for key, value in self.options.__dict__.items())
# ),
# self.kwargs,
# typed=True,
# )
def __init__(self, options: Namespace, **kwargs):
self.usage = ModelUsage()
self.chat = options.chat
# store kwargs for caching
self.options = options
self.kwargs = kwargs
def create_completion(
self,
system_message: str | None,
......@@ -113,8 +145,12 @@ class LLMModel(ABC):
# logger.warning
use_cache = False
# TODO implement caching
return model_completion_fn(**kwargs)
if use_cache:
# use cached function call
cache_key = self._compute_cache_key(model_completion_fn.__name__)
return cached_function_call(model_completion_fn, cache_key, **kwargs)
else:
return model_completion_fn(**kwargs)
@classmethod
@abstractmethod
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment