remove is_chat argument

Merged Grießhaber Daniel requested to merge remove-is-chat into refactor-models
2 files changed: +135 −107
import functools
import hashlib
import inspect
import json
import logging
import warnings
from abc import ABC, abstractmethod
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Any, Callable, ClassVar
import llama_cpp
import openai
@@ -19,10 +20,11 @@ logger = logging.getLogger(__name__)
logging.captureWarnings(True)
warnings.simplefilter("once")
ChatMessages = list[dict[str, str]]
class LLMModel(ABC):
models: ClassVar[dict[str, type["LLMModel"]]] = {}
chat: bool
def __init_subclass__(cls) -> None:
if inspect.isabstract(cls):
@@ -36,26 +38,25 @@ class LLMModel(ABC):
raise ValueError("Model %s does not exist", name)
return cls.models[name](options=options, **kwargs)
@functools.lru_cache
def _compute_cache_key(self, name, **kwargs):
# we use a tuple of the model name, the options, and the kwargs as the cache key
return (name,) + tuple((key, value) for key, value in kwargs.items())
def __init__(self, options: Namespace, **kwargs):
self.usage = ModelUsage()
self.chat = options.chat
# store kwargs for caching
self.options = options
self.kwargs = kwargs
# set up caching for model calls
self._call_model_cached = None
if not options.disable_cache:
cache = Cache(Path(".cache_dir", self.model_cache_key))
self._call_model_cached = cache.memoize(typed=True, ignore=[0, "func"])(
self._call_model_cached
)
@cache.memoize(typed=True, ignore=[0, "func"])
def _call_function(func, *args, **kwargs):
return func(*args, **kwargs)
self._call_model_cached = _call_function
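For context, the caching setup above relies on diskcache's memoize with ignore=[0, "func"], which keeps the wrapped callable itself out of the cache key so that only the remaining arguments decide whether a stored result is reused. A minimal, self-contained sketch of the same pattern (directory name and example function are invented):

from diskcache import Cache

cache = Cache(".cache_dir/example-model")  # placeholder directory

# position 0 / "func" is excluded from the cache key
@cache.memoize(typed=True, ignore=[0, "func"])
def call_function(func, *args, **kwargs):
    return func(*args, **kwargs)

def square(x):
    print("computed")
    return x * x

call_function(square, 4)  # prints "computed" and stores the result on disk
call_function(square, 4)  # answered from the cache, nothing is printed

Because the callable is ignored, two different functions called with the same arguments would share a cache entry; scoping the Cache directory by model_cache_key, as done above, is what keeps models apart.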
@abstractmethod
def create_completion(
self,
system_message: str | None,
@@ -65,48 +66,10 @@ class LLMModel(ABC):
prompt_appendix: str = "",
prompt_prefix: str = "",
prompt_suffix: str = "",
chat: bool | None = None,
stop: str = None,
history: dict = None,
history: ChatMessages | None = None,
**kwargs: Any,
) -> tuple[str, ModelUsage]:
if chat is None:
chat = self.chat
max_tokens = kwargs.pop("max_tokens", self.options.max_tokens)
# create prompt
prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
if not chat and system_message:
prompt = system_message + prompt
messages = [self._get_user_message(prompt)]
if chat:
# a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
# TODO is it better to check for a system message in the history?
if history is not None:
messages = history + messages
elif system_message:
messages.insert(
0,
self._get_system_message(system_message),
)
model_input = {"messages": messages}
else:
model_input = {"prompt": prompt}
response, usage = self._create_completion(
chat=chat,
**model_input,
stop=stop,
max_tokens=max_tokens,
use_cache=use_cache,
**kwargs,
)
messages.append(self._get_assistant_message(response))
return response, messages, usage
) -> tuple[str, ChatMessages, ModelUsage]: ...
def _get_user_message(self, content: str):
return {
@@ -134,18 +97,26 @@ class LLMModel(ABC):
warnings.warn("Caching is disabled when a grammar is provided.")
use_cache = False
if use_cache:
# use cached function call
cache_key = self._compute_cache_key(
model_completion_fn.__name__, **self.options.__dict__, **self.kwargs
)
return self._call_model_cached(model_completion_fn, cache_key, **kwargs)
if use_cache and self._call_model_cached is not None:
return self._call_model_cached(model_completion_fn, **kwargs)
else:
return model_completion_fn(**kwargs)
def _call_model_cached(self, func, cache_key, *args, **kwargs):
# `cache_key` is added to the cache key (e.g., to distinguish between different models), but it is not used in the function
return func(*args, **kwargs)
@property
def model_cache_key(self):
unique_options_key = json.dumps(
vars(self.options),
sort_keys=True,
) + json.dumps(
self.kwargs,
sort_keys=True,
)
cache_key = (
str(self.model_name).replace("/", "_")
+ "/"
+ hashlib.sha1(unique_options_key.encode()).hexdigest()
)
return cache_key
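As a quick illustration of the key scheme above: the options and kwargs are serialized deterministically (sort_keys=True) and hashed, so any change in configuration ends up in its own cache directory. The values below are invented stand-ins:

import hashlib
import json

options = {"max_tokens": 256, "disable_cache": False}  # stand-in for vars(self.options)
kwargs = {"n_ctx": 4096}                               # stand-in for self.kwargs
unique_options_key = json.dumps(options, sort_keys=True) + json.dumps(kwargs, sort_keys=True)
print("my-org_my-model/" + hashlib.sha1(unique_options_key.encode()).hexdigest())
# e.g. my-org_my-model/<40 hex characters>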
@classmethod
@abstractmethod
@@ -205,34 +176,51 @@ class Llama(LLMModel):
# needs to be called after model is initialized
super().__init__(options=options, n_ctx=n_ctx, **kwargs)
def create_completion(
self,
system_message: str | None,
prompt: str,
*,
use_cache: bool = False,
prompt_appendix: str = "",
prompt_prefix: str = "",
prompt_suffix: str = "",
stop: str = None,
history: ChatMessages | None = None,
**kwargs: Any,
) -> tuple[str, ChatMessages, ModelUsage]:
# create prompt
prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
messages = [self._get_user_message(prompt)]
if system_message is not None:
prompt = system_message + prompt
response, usage = self._create_completion(
prompt=prompt,
stop=stop,
use_cache=use_cache,
max_tokens=self.options.max_tokens,
**kwargs,
)
messages.append(self._get_assistant_message(response))
return response, messages, usage
def _create_completion(
self,
chat: bool,
use_cache: bool = False,
**kwargs,
):
if chat:
response = self._call_model(
self.model.create_chat_completion,
use_cache=use_cache,
**kwargs,
)
response_text = response["choices"][0]["message"]["content"]
else:
response = self._call_model(
self.model.create_completion,
use_cache=use_cache,
**kwargs,
)
response_text = response["choices"][0]["text"]
response = self._call_model(
self.model.create_completion,
use_cache=use_cache,
**kwargs,
)
response_text = response["choices"][0]["text"]
usage = ModelUsage(**response["usage"])
return response_text, usage
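The non-chat path above reduces to a single llama-cpp-python call. A hedged sketch of that call, assuming the package is installed and a local GGUF file exists at the placeholder path:

import llama_cpp

llm = llama_cpp.Llama(model_path="models/example.gguf", n_ctx=2048)  # placeholder path
response = llm.create_completion(
    prompt="Q: What is 2 + 2?\nA:",
    max_tokens=16,
    stop=["\n"],
)
print(response["choices"][0]["text"])  # generated continuation
print(response["usage"])               # prompt_tokens / completion_tokens / total_tokens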
@property
def model_cache_key(self):
return self.model_name
@classmethod
def register_arguments(cls, parser: ArgumentParser):
group = parser.add_argument_group(f"{cls.__name__} model arguments")
@@ -272,7 +260,61 @@ class Llama(LLMModel):
)
class OpenAI(LLMModel):
class ChatModel:
def create_completion(
self,
system_message: str | None,
prompt: str,
*,
use_cache: bool = False,
prompt_appendix: str = "",
prompt_prefix: str = "",
prompt_suffix: str = "",
stop: str = None,
history: ChatMessages | None = None,
**kwargs: Any,
) -> tuple[str, ChatMessages, ModelUsage]:
# create prompt
prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
messages = [self._get_user_message(prompt)]
# a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
# TODO is it better to check for a system message in the history?
if history is None:
history = [self._get_system_message(system_message)] if system_message else []
response, usage = self._create_completion(
messages=messages,
stop=stop,
use_cache=use_cache,
max_tokens=self.options.max_tokens,
**kwargs,
)
messages.append(self._get_assistant_message(response))
return response, history + messages, usage
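To make the history handling concrete: the second element of the returned tuple is the full conversation so far, which callers can pass back in as history on the next turn. A data-only sketch with invented content:

# first turn: no history given, so the system message starts the conversation
history = [{"role": "system", "content": "You are a helpful assistant."}]
messages = [{"role": "user", "content": "Hi there"}]
messages.append({"role": "assistant", "content": "Hello! How can I help?"})
conversation = history + messages  # what the method returns as its second element

# next turn: pass `conversation` back in as `history`; only the new user message is appended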
class LlamaChat(ChatModel, Llama):
def _create_completion(
self,
use_cache: bool = False,
**kwargs,
):
response = self._call_model(
self.model.create_chat_completion,
use_cache=use_cache,
**kwargs,
)
response_text = response["choices"][0]["message"]["content"]
usage = ModelUsage(**response["usage"])
return response_text, usage
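The split between ChatModel and the backends leans on Python's method resolution order: the chat mixin supplies create_completion (assembling the prompt into messages), while the concrete class supplies _create_completion (the backend call). A stand-alone sketch with made-up class names:

class Backend:
    def create_completion(self, prompt):
        return self._create_completion(prompt=prompt)

    def _create_completion(self, **kwargs):
        return ("completion", kwargs)

class ChatMixin:
    def create_completion(self, prompt):
        return self._create_completion(messages=[{"role": "user", "content": prompt}])

class ChatBackend(ChatMixin, Backend):
    def _create_completion(self, **kwargs):
        return ("chat", kwargs)

print(ChatBackend().create_completion("hi"))
# ('chat', {'messages': [{'role': 'user', 'content': 'hi'}]})
print([cls.__name__ for cls in ChatBackend.__mro__])
# ['ChatBackend', 'ChatMixin', 'Backend', 'object']

Listing ChatModel first in LlamaChat(ChatModel, Llama) is what makes the chat-style create_completion win over the base implementation.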
class OpenAiChat(ChatModel, LLMModel):
"""Queries an OpenAI model using its API."""
def __init__(
@@ -288,34 +330,19 @@ class OpenAI(LLMModel):
def _create_completion(
self,
chat: bool,
use_cache: bool = False,
**kwargs,
):
if chat:
response = self._call_model(
self.openai_client.chat.completions.create,
model=self.model_name,
use_cache=use_cache,
**kwargs,
)
response_text = response.choices[0].message.content
else:
response = self._call_model(
self.openai_client.completions.create,
model=self.model,
use_cache=use_cache,
**kwargs,
)
response_text = response.choices[0].text
response = self._call_model(
self.openai_client.chat.completions.create,
model=self.model_name,
use_cache=use_cache,
**kwargs,
)
response_text = response.choices[0].message.content
usage = ModelUsage(**response.usage.__dict__)
return response_text, usage
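For reference, the wrapped call above corresponds to the chat completions endpoint of the openai>=1.x client. A hedged sketch, assuming OPENAI_API_KEY is set in the environment and using a placeholder model name:

import openai

client = openai.OpenAI()
response = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model name
    messages=[
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "Say hello."},
    ],
    max_tokens=16,
)
print(response.choices[0].message.content)
print(response.usage)  # has prompt_tokens / completion_tokens / total_tokens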
@property
def model_cache_key(self):
return self.model_name
@classmethod
def register_arguments(cls, parser: ArgumentParser):
group = parser.add_argument_group("OpenAI model arguments")
@@ -339,4 +366,3 @@ argument_group.add_argument(
type=int,
help="Maximum number of tokens being generated from LLM. ",
)
argument_group.add_argument("--chat", "-c", action="store_true")