remove is_chat argument

Merged: Grießhaber Daniel requested to merge remove-is-chat into refactor-models
2 files changed, +136 −90
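What the change amounts to: instead of a global `--chat`/`options.chat` flag that `create_completion` and every `_create_completion` had to branch on, each backend now supplies a `build_model_input` hook, and the chat flavours become their own subclasses (`LlamaChat`, `OpenAiChat`). A minimal sketch of the resulting dispatch, with hypothetical class names and a stubbed backend call, assuming the usual role/content message shape:

from abc import ABC, abstractmethod

class Model(ABC):
    @abstractmethod
    def build_model_input(self, prompt, system_message, messages, history=None):
        """Return the kwargs that this backend's completion call expects."""

    def create_completion(self, prompt, system_message=None, history=None):
        messages = [{"role": "user", "content": prompt}]
        model_input = self.build_model_input(prompt, system_message, messages, history)
        return self._create_completion(**model_input)

class CompletionModel(Model):
    def build_model_input(self, prompt, system_message, messages, history=None):
        # plain completion endpoints take a single prompt string
        if system_message is not None:
            prompt = system_message + prompt
        return {"prompt": prompt}

    def _create_completion(self, prompt):
        return f"completion({prompt!r})"  # stand-in for the real backend call

class ChatModel(Model):
    def build_model_input(self, prompt, system_message, messages, history=None):
        # chat endpoints take a message list; history (incl. its system message) goes first
        if history is not None:
            messages = history + messages
        elif system_message:
            messages = [{"role": "system", "content": system_message}] + messages
        return {"messages": messages}

    def _create_completion(self, messages):
        return f"chat({messages!r})"  # stand-in for the real backend call

print(CompletionModel().create_completion("Say hi.", system_message="Be brief. "))
print(ChatModel().create_completion("Say hi.", system_message="Be brief."))

Choosing chat behaviour then means instantiating the chat subclass rather than threading chat=True through every call.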
import functools
import hashlib
import inspect
import json
import logging
import warnings
from abc import ABC, abstractmethod
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Any, Callable, ClassVar
import llama_cpp
import openai
@@ -22,7 +23,6 @@ warnings.simplefilter("once")
class LLMModel(ABC):
    models: ClassVar[dict[str, type["LLMModel"]]] = {}
    chat: bool

    def __init_subclass__(cls) -> None:
        if inspect.isabstract(cls):
@@ -36,25 +36,33 @@ class LLMModel(ABC):
            raise ValueError(f"Model {name} does not exist")
        return cls.models[name](options=options, **kwargs)
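Only fragments of the registry behind this factory are visible in the hunk (the `models` ClassVar, the `isabstract` guard, and the `cls.models[name](...)` lookup). A self-contained sketch of that registration pattern; the classmethod name `get` and the lowercase-class-name key are assumptions, not the project's actual API:

import inspect
from abc import ABC, abstractmethod

class Base(ABC):
    registry: dict[str, type["Base"]] = {}

    def __init_subclass__(cls) -> None:
        if inspect.isabstract(cls):
            return  # abstract intermediates are not selectable models
        Base.registry[cls.__name__.lower()] = cls  # key format is an assumption

    @classmethod
    def get(cls, name: str, **kwargs) -> "Base":
        if name not in cls.registry:
            raise ValueError(f"Model {name} does not exist")
        return cls.registry[name](**kwargs)

    @abstractmethod
    def run(self) -> str: ...

class Echo(Base):
    def run(self) -> str:
        return "echo"

print(Base.get("echo").run())  # -> echo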
    @functools.lru_cache
    def _compute_cache_key(self, name, **kwargs):
        # we use a tuple of the model name, the options, and the kwargs as the cache key
        return (name,) + tuple((key, value) for key, value in kwargs.items())

    def __init__(self, options: Namespace, **kwargs):
        self.usage = ModelUsage()
        self.chat = options.chat
        # store kwargs for caching
        self.options = options
        self.kwargs = kwargs

        # set up caching for model calls
        self._call_model_cached = None
        if not options.disable_cache:
            cache = Cache(Path(".cache_dir", self.model_cache_key))
            self._call_model_cached = cache.memoize(typed=True, ignore=[0, "func"])(
                self._call_model_cached
            )

            @cache.memoize(typed=True, ignore=["func"])
            def _call_function(func, *args, **kwargs):
                return func(*args, **kwargs)

            self._call_model_cached = _call_function
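The new setup memoizes a local trampoline instead of wrapping a bound method. `Cache` is imported outside the shown hunks; assuming it is `diskcache.Cache`, the sketch below shows what `ignore` does: diskcache drops keyword arguments from the key by name and positional arguments by index, which is why the previous revision used `ignore=[0, "func"]`. Here the callable is passed as a keyword so that `ignore=["func"]` really excludes it:

from diskcache import Cache  # assumption: the project's Cache is diskcache.Cache

cache = Cache(".cache_dir/example")

@cache.memoize(typed=True, ignore=["func"])
def call_function(*args, func, **kwargs):
    # `func` never enters the cache key, so different callables with identical
    # remaining arguments would share an entry; per-model separation therefore
    # has to come from somewhere else (here: a per-model cache directory).
    return func(*args, **kwargs)

print(call_function("hello", func=str.upper))  # computed and stored on disk
print(call_function("hello", func=str.upper))  # served from the cache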
    @abstractmethod
    def build_model_input(
        self,
        prompt: str,
        system_message: str | None,
        messages: list[dict[str, str]],
        history: list[dict[str, str]] | None = None,
    ):
        pass

    def create_completion(
        self,
@@ -65,42 +73,18 @@ class LLMModel(ABC):
        prompt_appendix: str = "",
        prompt_prefix: str = "",
        prompt_suffix: str = "",
        chat: bool | None = None,
        stop: str = None,
        history: dict = None,
        history: list[dict[str, str]] | None = None,
        **kwargs: Any,
    ) -> tuple[str, ModelUsage]:
        if chat is None:
            chat = self.chat
        max_tokens = kwargs.pop("max_tokens", self.options.max_tokens)

        # create prompt
        prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
        if not chat and system_message:
            prompt = system_message + prompt
        messages = [self._get_user_message(prompt)]

        if chat:
            # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
            # TODO is it better to check for a system message in the history?
            if history is not None:
                messages = history + messages
            elif system_message:
                messages.insert(
                    0,
                    self._get_system_message(system_message),
                )
            model_input = {"messages": messages}
        else:
            model_input = {"prompt": prompt}
        model_input = self.build_model_input(prompt, system_message, messages, history)
        response, usage = self._create_completion(
            chat=chat,
            **model_input,
            stop=stop,
            max_tokens=max_tokens,
            use_cache=use_cache,
            **kwargs,
        )
@@ -134,18 +118,26 @@ class LLMModel(ABC):
            warnings.warn("Caching is disabled when a grammar is provided.")
            use_cache = False

        if use_cache:
            # use cached function call
            cache_key = self._compute_cache_key(
                model_completion_fn.__name__, **self.options.__dict__, **self.kwargs
            )
            return self._call_model_cached(model_completion_fn, cache_key, **kwargs)
        if use_cache and self._call_model_cached is not None:
            return self._call_model_cached(model_completion_fn, **kwargs)
        else:
            return model_completion_fn(**kwargs)

    def _call_model_cached(self, func, cache_key, *args, **kwargs):
        # `cache_key` is added to the cache key (e.g., to distinguish between different models), but it is not used in the function
        return func(*args, **kwargs)
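The removed `_call_model_cached` above leaned on an argument that exists only to widen the memoization key. The same trick in isolation, using `functools.lru_cache` for brevity rather than the project's disk cache:

import functools

@functools.lru_cache(maxsize=None)
def cached_call(func, cache_key, *args):
    # `cache_key` only serves to distinguish cache entries (e.g. per model);
    # the wrapped call never looks at it
    return func(*args)

print(cached_call(str.upper, ("model-a",), "hello"))  # miss: computed
print(cached_call(str.upper, ("model-b",), "hello"))  # miss: different key
print(cached_call(str.upper, ("model-a",), "hello"))  # hit: served from the cache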
    @property
    def model_cache_key(self):
        unique_options_key = json.dumps(
            vars(self.options),
            sort_keys=True,
        ) + json.dumps(
            self.kwargs,
            sort_keys=True,
        )
        cache_key = (
            str(self.model_name).replace("/", "_")
            + "/"
            + hashlib.sha1(unique_options_key.encode()).hexdigest()
        )
        return cache_key
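The new `model_cache_key` derives the cache directory from the full configuration, so two differently configured instances of the same model no longer share cached completions. Roughly, with made-up option values:

import hashlib
import json

model_name = "my-org/my-model"                          # hypothetical
options = {"max_tokens": 256, "disable_cache": False}   # stand-in for vars(self.options)
kwargs = {"n_ctx": 4096}                                # stand-in for self.kwargs

unique_options_key = json.dumps(options, sort_keys=True) + json.dumps(kwargs, sort_keys=True)
print(model_name.replace("/", "_") + "/" + hashlib.sha1(unique_options_key.encode()).hexdigest())
# -> my-org_my-model/<40-character sha1 hex digest>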
    @classmethod
    @abstractmethod
@@ -205,34 +197,32 @@ class Llama(LLMModel):
        # needs to be called after model is initialized
        super().__init__(options=options, n_ctx=n_ctx, **kwargs)

    def build_model_input(
        self,
        prompt: str,
        system_message: str | None,
        messages: list[dict[str, str]],
        history: list[dict[str, str]] | None = None,
    ):
        if system_message is not None:
            prompt = system_message + prompt
        return {"prompt": prompt}

    def _create_completion(
        self,
        chat: bool,
        use_cache: bool = False,
        **kwargs,
    ):
        if chat:
            response = self._call_model(
                self.model.create_chat_completion,
                use_cache=use_cache,
                **kwargs,
            )
            response_text = response["choices"][0]["message"]["content"]
        else:
            response = self._call_model(
                self.model.create_completion,
                use_cache=use_cache,
                **kwargs,
            )
            response_text = response["choices"][0]["text"]
        response = self._call_model(
            self.model.create_completion,
            use_cache=use_cache,
            **kwargs,
        )
        response_text = response["choices"][0]["text"]

        usage = ModelUsage(**response["usage"])
        return response_text, usage

    @property
    def model_cache_key(self):
        return self.model_name

    @classmethod
    def register_arguments(cls, parser: ArgumentParser):
        group = parser.add_argument_group(f"{cls.__name__} model arguments")
@@ -272,6 +262,43 @@ class Llama(LLMModel):
        )

class LlamaChat(Llama):
    def _create_completion(
        self,
        use_cache: bool = False,
        **kwargs,
    ):
        response = self._call_model(
            self.model.create_chat_completion,
            use_cache=use_cache,
            **kwargs,
        )
        response_text = response["choices"][0]["message"]["content"]
        usage = ModelUsage(**response["usage"])
        return response_text, usage

    def build_model_input(
        self,
        prompt: str,
        system_message: str | None,
        messages: list[dict[str, str]],
        history: list[dict[str, str]] | None = None,
    ):
        # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
        # TODO is it better to check for a system message in the history?
        if history is not None:
            messages[:0] = history
        elif system_message:
            messages.insert(
                0,
                self._get_system_message(system_message),
            )
        return {"messages": messages}
class OpenAI(LLMModel):
"""Queries an OpenAI model using its API."""
@@ -288,33 +315,33 @@ class OpenAI(LLMModel):
    def _create_completion(
        self,
        chat: bool,
        use_cache: bool = False,
        **kwargs,
    ):
        if chat:
            response = self._call_model(
                self.openai_client.chat.completions.create,
                model=self.model_name,
                use_cache=use_cache,
                **kwargs,
            )
            response_text = response.choices[0].message.content
        else:
            response = self._call_model(
                self.openai_client.completions.create,
                model=self.model,
                use_cache=use_cache,
                **kwargs,
            )
            response_text = response.choices[0].text
        response = self._call_model(
            self.openai_client.completions.create,
            model=self.model,
            use_cache=use_cache,
            **kwargs,
        )
        response_text = response.choices[0].text

        usage = ModelUsage(**response.usage.__dict__)
        return response_text, usage
    @property
    def model_cache_key(self):
        return self.model_name

    def build_model_input(
        self,
        prompt: str,
        system_message: str | None,
        messages: list[dict[str, str]],
        history: list[dict[str, str]] | None = None,
    ):
        return {
            "prompt": prompt,
            "system_message": system_message,
            "messages": messages,
            "history": history,
        }
    @classmethod
    def register_arguments(cls, parser: ArgumentParser):
@@ -322,6 +349,24 @@ class OpenAI(LLMModel):
group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")

class OpenAiChat(OpenAI):
    def _create_completion(
        self,
        use_cache: bool = False,
        **kwargs,
    ):
        response = self._call_model(
            self.openai_client.chat.completions.create,
            model=self.model_name,
            use_cache=use_cache,
            **kwargs,
        )
        response_text = response.choices[0].message.content
        usage = ModelUsage(**response.usage.__dict__)
        return response_text, usage
argument_group = argument_parser.add_argument_group("Model arguments")
argument_group.add_argument(
"--evolution-engine",
@@ -339,4 +384,3 @@ argument_group.add_argument(
    type=int,
    help="Maximum number of tokens being generated from LLM. ",
)
argument_group.add_argument("--chat", "-c", action="store_true")