
remove is_chat argument

Merged Grießhaber Daniel requested to merge remove-is-chat into refactor-models
2 files changed: +32 -27
@@ -44,11 +44,15 @@ class LLMModel(ABC):
         self.kwargs = kwargs
         # set up caching for model calls
+        self._call_model_cached = None
         if not options.disable_cache:
             cache = Cache(Path(".cache_dir", self.model_cache_key))
-            self._call_model_cached = cache.memoize(typed=True, ignore=[0, "func"])(
-                self._call_model_cached
-            )
+            @cache.memoize(typed=True, ignore=["func"])
+            def _call_function(func, *args, **kwargs):
+                return func(*args, **kwargs)
+
+            self._call_model_cached = _call_function

     @abstractmethod
     def build_model_input(
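The block above swaps the memoized bound method for a locally defined `_call_function` wrapper that is memoized by the `Cache` object (presumably `diskcache.Cache`, which the `memoize(typed=..., ignore=...)` signature matches). A minimal, self-contained sketch of the same pattern; the cache directory and the completion function below are made up for illustration:

```python
from diskcache import Cache

cache = Cache(".cache_dir/example-model")  # hypothetical cache directory

@cache.memoize(typed=True, ignore=["func"])
def _call_function(func, *args, **kwargs):
    # the keyword argument named "func" is excluded from the cache key,
    # so results are keyed only on the remaining arguments
    return func(*args, **kwargs)

def fake_completion(prompt):
    print("computing")  # printed only on a cache miss
    return prompt.upper()

# passing the callable by keyword keeps it out of the cache key
print(_call_function(func=fake_completion, prompt="hello"))  # miss: runs fake_completion
print(_call_function(func=fake_completion, prompt="hello"))  # hit: served from the disk cache
```

The `if use_cache and self._call_model_cached is not None` guard further down then only routes calls through this wrapper when caching was actually set up.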
@@ -73,9 +77,6 @@ class LLMModel(ABC):
         history: list[dict[str, str]] | None = None,
         **kwargs: Any,
     ) -> tuple[str, ModelUsage]:
-        if chat is None:
-            chat = self.chat
-
         # create prompt
         prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
         messages = [self._get_user_message(prompt)]
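For context on the prompt assembly kept in this hunk: the four pieces are concatenated into a single user message. A rough sketch, assuming `_get_user_message` simply wraps text in an OpenAI-style role/content dict (its body is not part of this diff):

```python
def _get_user_message(content: str) -> dict[str, str]:
    # assumed helper: wrap plain text as a chat "user" message
    return {"role": "user", "content": content}

prompt_prefix = "Instruction: "
prompt = "Summarize the input."
prompt_suffix = "\n\nInput: <text to summarize>"
prompt_appendix = "\n\nSummary:"

full_prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
messages = [_get_user_message(full_prompt)]
print(messages)
```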
@@ -117,16 +118,11 @@ class LLMModel(ABC):
             warnings.warn("Caching is disabled when a grammar is provided.")
             use_cache = False

-        if use_cache:
-            # use cached function call
+        if use_cache and self._call_model_cached is not None:
             return self._call_model_cached(model_completion_fn, **kwargs)
         else:
             return model_completion_fn(**kwargs)

-    def _call_model_cached(self, func, *args, **kwargs):
-        # `cache_key` is added to the cache key (e.g., to distinguish between different models), but it is not used in the function
-        return func(*args, **kwargs)
-
     @property
     def model_cache_key(self):
         unique_options_key = json.dumps(
@@ -137,7 +133,9 @@ class LLMModel(ABC):
             sort_keys=True,
         )
         cache_key = (
-            str(self.model_name) + hashlib.sha1(unique_options_key.encode()).hexdigest()
+            str(self.model_name).replace("/", "_")
+            + "/"
+            + hashlib.sha1(unique_options_key.encode()).hexdigest()
         )
         return cache_key
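The new key layout splits the cache path into a per-model directory plus an options hash, presumably so that model names containing "/" (e.g. Hugging Face style names) no longer expand into unintended nesting under `.cache_dir`. A small standalone sketch with a made-up model name and option set:

```python
import hashlib
import json

model_name = "meta-llama/Llama-2-7b-chat-hf"  # hypothetical model name
unique_options_key = json.dumps({"max_tokens": 256, "temperature": 0.0}, sort_keys=True)

cache_key = (
    str(model_name).replace("/", "_")
    + "/"
    + hashlib.sha1(unique_options_key.encode()).hexdigest()
)
print(cache_key)  # 'meta-llama_Llama-2-7b-chat-hf/<40-char sha1 hex digest>'
```

Since the key is passed to `Path(".cache_dir", self.model_cache_key)` in the constructor, each model gets its own directory with one cache per options hash.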
@@ -331,13 +329,27 @@ class OpenAI(LLMModel):
         usage = ModelUsage(**response.usage.__dict__)
         return response_text, usage

+    def build_model_input(
+        self,
+        prompt: str,
+        system_message: str | None,
+        messages: list[dict[str, str]],
+        history: list[dict[str, str]] | None = None,
+    ):
+        return {
+            "prompt": prompt,
+            "system_message": system_message,
+            "messages": messages,
+            "history": history,
+        }
+
     @classmethod
     def register_arguments(cls, parser: ArgumentParser):
         group = parser.add_argument_group("OpenAI model arguments")
         group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")

-def OpenAiChat(OpenAI):
+class OpenAiChat(OpenAI):
     def _create_completion(
         self,
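The `register_arguments` hook shown as context here follows the usual argparse argument-group pattern. A quick usage sketch; the stripped-down `OpenAI` stand-in below carries only that one classmethod:

```python
from argparse import ArgumentParser

class OpenAI:
    # stand-in for the real class; only the argument registration is shown
    @classmethod
    def register_arguments(cls, parser: ArgumentParser):
        group = parser.add_argument_group("OpenAI model arguments")
        group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")

parser = ArgumentParser()
OpenAI.register_arguments(parser)
options = parser.parse_args(["-m", "gpt-4"])
print(options.openai_model)  # -> gpt-4
```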
@@ -372,4 +384,3 @@ argument_group.add_argument(
     type=int,
     help="Maximum number of tokens being generated from LLM. ",
 )
-argument_group.add_argument("--chat", "-c", action="store_true")