Commit 3ae988c3 authored by Grießhaber Daniel

remove need for cache key computation during runtime

parent 2df18f7f
2 merge requests: !2 remove is_chat argument, !1 Refactor models
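Summary of the change: previously, every call with use_cache enabled rebuilt a cache key at runtime through `_compute_cache_key`, an `lru_cache`-decorated helper that packed the class name, options, and kwargs into a tuple, and that key was passed into `_call_model_cached`. Because all of those inputs are fixed when the model is constructed, the key can be derived once instead: the new `model_cache_key` property returns the model name concatenated with a SHA-1 digest of the JSON-serialized options and kwargs, and the per-subclass overrides in Llama and OpenAI (which returned only `model_name`) are dropped. A minimal sketch of the two key styles, simplified and not taken verbatim from the repository (the option values below are illustrative):

import functools
import hashlib
import json

options = {"temperature": 0.0, "max_tokens": 256}  # illustrative constructor options
kwargs = {"seed": 42}                               # illustrative extra kwargs

# before: a tuple key assembled (and memoized) on every call
@functools.lru_cache
def compute_cache_key(name, **kw):
    return (name,) + tuple((key, value) for key, value in kw.items())

key_old = compute_cache_key("Llama", **options, **kwargs)

# after: a stable string key derived once from the fixed configuration
unique_options_key = json.dumps(options, sort_keys=True) + json.dumps(kwargs, sort_keys=True)
key_new = "Llama" + hashlib.sha1(unique_options_key.encode()).hexdigest()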
import functools
import hashlib
import inspect
import json
import logging
import warnings
from abc import ABC, abstractmethod
@@ -35,11 +36,6 @@ class LLMModel(ABC):
            raise ValueError(f"Model {name} does not exist")
        return cls.models[name](options=options, **kwargs)

    @functools.lru_cache
    def _compute_cache_key(self, name, **kwargs):
        # we use a tuple of the model name, the options, and the kwargs as the cache key
        return (name,) + tuple((key, value) for key, value in kwargs.items())

    def __init__(self, options: Namespace, **kwargs):
        self.usage = ModelUsage()
@@ -120,17 +116,28 @@ class LLMModel(ABC):
        if use_cache:
            # use cached function call
            cache_key = self._compute_cache_key(
                self.__class__.__name__, **self.options.__dict__, **self.kwargs
            )
            return self._call_model_cached(model_completion_fn, cache_key, **kwargs)
            return self._call_model_cached(model_completion_fn, **kwargs)
        else:
            return model_completion_fn(**kwargs)

    def _call_model_cached(self, func, cache_key, *args, **kwargs):
    def _call_model_cached(self, func, *args, **kwargs):
        # `cache_key` is added to the cache key (e.g., to distinguish between different models), but it is not used in the function
        return func(*args, **kwargs)
    @property
    def model_cache_key(self):
        unique_options_key = json.dumps(
            vars(self.options),
            sort_keys=True,
        ) + json.dumps(
            self.kwargs,
            sort_keys=True,
        )
        cache_key = (
            str(self.model_name) + hashlib.sha1(unique_options_key.encode()).hexdigest()
        )
        return cache_key

    @classmethod
    @abstractmethod
    def register_arguments(cls, parser: ArgumentParser):
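The decorator that actually applies the caching to `_call_model_cached` is not visible in this diff, so the following is a hedged sketch only: it shows how a configuration-derived string like the `model_cache_key` property above can serve as the per-model part of a cache index while the per-call arguments supply the rest. `FakeModel`, `persistent_cache`, and `cached_complete` are hypothetical stand-ins, not code from this repository:

import hashlib
import json

class FakeModel:
    # hypothetical model with the same model_cache_key shape as the property above
    def __init__(self, model_name: str, **options):
        self.model_name = model_name
        self.options = options

    @property
    def model_cache_key(self) -> str:
        unique_options_key = json.dumps(self.options, sort_keys=True)
        return self.model_name + hashlib.sha1(unique_options_key.encode()).hexdigest()

    def complete(self, prompt: str) -> str:
        return f"completion for: {prompt}"  # placeholder for a real model call

persistent_cache: dict[tuple[str, str], str] = {}  # stand-in for an on-disk cache

def cached_complete(model: FakeModel, prompt: str) -> str:
    # the model part of the key comes from the configuration-derived property;
    # only the prompt varies per request
    key = (model.model_cache_key, prompt)
    if key not in persistent_cache:
        persistent_cache[key] = model.complete(prompt)
    return persistent_cache[key]

model = FakeModel("llama-2-7b", temperature=0.0)
assert cached_complete(model, "Hello") == cached_complete(model, "Hello")  # second call is served from the cache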
@@ -196,7 +203,6 @@ class Llama(LLMModel):
        messages: list[dict[str, str]],
        history: list[dict[str, str]] | None = None,
    ):
        if system_message is not None:
            prompt = system_message + prompt
        return {"prompt": prompt}
@@ -216,10 +222,6 @@ class Llama(LLMModel):
        usage = ModelUsage(**response["usage"])
        return response_text, usage

    @property
    def model_cache_key(self):
        return self.model_name

    @classmethod
    def register_arguments(cls, parser: ArgumentParser):
        group = parser.add_argument_group(f"{cls.__name__} model arguments")
@@ -326,10 +328,6 @@ class OpenAI(LLMModel):
        usage = ModelUsage(**response.usage.__dict__)
        return response_text, usage

    @property
    def model_cache_key(self):
        return self.model_name

    @classmethod
    def register_arguments(cls, parser: ArgumentParser):
        group = parser.add_argument_group("OpenAI model arguments")
......
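The hunks above also show the `register_arguments` hook in passing: each model class contributes its own command-line flags to a shared parser via an argument group. A small sketch of that pattern, with made-up flag names since the real arguments are collapsed in this view:

from argparse import ArgumentParser

class LlamaLike:
    # illustrative subclass hook in the style of LLMModel.register_arguments
    @classmethod
    def register_arguments(cls, parser: ArgumentParser):
        group = parser.add_argument_group(f"{cls.__name__} model arguments")
        group.add_argument("--model-path", type=str)            # hypothetical flag
        group.add_argument("--n-ctx", type=int, default=2048)   # hypothetical flag

parser = ArgumentParser()
LlamaLike.register_arguments(parser)
args = parser.parse_args(["--model-path", "models/llama.gguf"])
print(args.model_path, args.n_ctx)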