Commit a78c704c authored by Grießhaber Daniel

refactor llm model abstraction

parent 66338b96
2 merge requests: !2 remove is_chat argument, !1 Refactor models
@@ -20,6 +20,8 @@ logger = logging.getLogger(__name__)
logging.captureWarnings(True)
warnings.simplefilter("once")
ChatMessages = list[dict[str, str]]
class LLMModel(ABC):
models: ClassVar[dict[str, type["LLMModel"]]] = {}
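For orientation, a ChatMessages value would look roughly like this (role names assumed to follow the OpenAI-style convention used by the message helpers in this module):

ChatMessages = list[dict[str, str]]

history: ChatMessages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi! How can I help?"},
]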
@@ -48,22 +50,13 @@ class LLMModel(ABC):
if not options.disable_cache:
cache = Cache(Path(".cache_dir", self.model_cache_key))
@cache.memoize(typed=True, ignore=["func"])
@cache.memoize(typed=True, ignore=[0, "func"])
def _call_function(func, *args, **kwargs):
return func(*args, **kwargs)
self._call_model_cached = _call_function
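As a standalone sketch (not this project's code), the cached dispatcher above can be reproduced with diskcache; the ignore list excludes the passed-in function object from the cache key whether it arrives as positional argument 0 or as the keyword func:

from pathlib import Path
from diskcache import Cache

cache = Cache(Path(".cache_dir", "example-model"))  # hypothetical cache directory

@cache.memoize(typed=True, ignore=[0, "func"])
def call_cached(func, *args, **kwargs):
    # Only *args and **kwargs contribute to the cache key; the function object
    # (positional index 0 or keyword "func") is ignored.
    return func(*args, **kwargs)

print(call_cached(str.upper, "hello"))  # computed
print(call_cached(str.upper, "hello"))  # served from the on-disk cache

In the class above, the per-model cache directory (Path(".cache_dir", self.model_cache_key)) keeps entries from different models separate even though the function object itself is excluded from the key.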
@abstractmethod
def build_model_input(
self,
prompt: str,
system_message: str | None,
messages: list[dict[str, str]],
history: list[dict[str, str]] | None = None,
):
pass
def create_completion(
self,
system_message: str | None,
@@ -74,26 +67,9 @@
prompt_prefix: str = "",
prompt_suffix: str = "",
stop: str = None,
history: list[dict[str, str]] | None = None,
history: ChatMessages | None = None,
**kwargs: Any,
) -> tuple[str, ModelUsage]:
# create prompt
prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
messages = [self._get_user_message(prompt)]
model_input, messages = self.build_model_input(
prompt, system_message, messages, history
)
reponse, usage = self._create_completion(
**model_input,
stop=stop,
use_cache=use_cache,
max_tokens=self.options.max_tokens,
**kwargs,
)
messages.append(self._get_assistant_message(reponse))
return reponse, messages, usage
) -> tuple[str, ModelUsage]: ...
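A hypothetical call site for this interface, assuming a concrete model instance named model (one of the classes below); following the concrete implementations, it returns the response text, the updated message list, and a ModelUsage record:

response, messages, usage = model.create_completion(
    system_message="You are a helpful assistant.",
    prompt="Name the capital of France.",
    prompt_suffix="\nAnswer:",
    use_cache=True,
)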
def _get_user_message(self, content: str):
return {
@@ -200,16 +176,35 @@ class Llama(LLMModel):
# needs to be called after model is initialized
super().__init__(options=options, n_ctx=n_ctx, **kwargs)
def build_model_input(
def create_completion(
self,
prompt: str,
system_message: str | None,
messages: list[dict[str, str]],
history: list[dict[str, str]] | None = None,
):
prompt: str,
*,
use_cache: bool = False,
prompt_appendix: str = "",
prompt_prefix: str = "",
prompt_suffix: str = "",
stop: str = None,
history: ChatMessages | None = None,
**kwargs: Any,
) -> tuple[str, ModelUsage]:
# create prompt
prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
messages = [self._get_user_message(prompt)]
if system_message is not None:
prompt = system_message + prompt
return {"prompt": prompt}, messages
reponse, usage = self._create_completion(
prompt=prompt,
stop=stop,
use_cache=use_cache,
max_tokens=self.options.max_tokens,
**kwargs,
)
messages.append(self._get_assistant_message(reponse))
return reponse, messages, usage
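Restated as a minimal, self-contained sketch (names are illustrative, not the project's API), the completion-style path above flattens everything into a single prompt string:

def build_completion_prompt(
    prompt: str,
    system_message: str | None = None,
    prompt_prefix: str = "",
    prompt_suffix: str = "",
    prompt_appendix: str = "",
) -> str:
    # Completion-style models receive one flat string rather than a message list.
    prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
    if system_message is not None:
        prompt = system_message + prompt
    return prompt

print(build_completion_prompt("What is 2+2?", system_message="Be terse. "))
# Be terse. What is 2+2?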
def _create_completion(
self,
@@ -265,7 +260,46 @@
)
class LlamaChat(Llama):
class ChatModel:
def create_completion(
self,
system_message: str | None,
prompt: str,
*,
use_cache: bool = False,
prompt_appendix: str = "",
prompt_prefix: str = "",
prompt_suffix: str = "",
stop: str = None,
history: ChatMessages | None = None,
**kwargs: Any,
) -> tuple[str, ModelUsage]:
# create prompt
prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
messages = [self._get_user_message(prompt)]
# a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
# TODO is it better to check for a system message in the history?
if history is not None:
messages = history + messages
[messages.insert(index, entry) for index, entry in enumerate(history)]
elif system_message:
messages = [self._get_system_message(system_message)] + messages
reponse, usage = self._create_completion(
messages=messages,
stop=stop,
use_cache=use_cache,
max_tokens=self.options.max_tokens,
**kwargs,
)
messages.append(self._get_assistant_message(reponse))
return reponse, messages, usage
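The chat-style path above, restated as a self-contained sketch (role/content dicts assumed, matching the module's message helpers); a provided history is treated as already containing the system message:

def build_chat_messages(
    prompt: str,
    system_message: str | None = None,
    history: list[dict[str, str]] | None = None,
) -> list[dict[str, str]]:
    messages = [{"role": "user", "content": prompt}]
    if history is not None:
        # The history is assumed to already carry a system message.
        messages = history + messages
    elif system_message:
        messages = [{"role": "system", "content": system_message}] + messages
    return messages

print(build_chat_messages("Hello", system_message="Be terse."))
# [{'role': 'system', 'content': 'Be terse.'}, {'role': 'user', 'content': 'Hello'}]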
class LlamaChat(Llama, ChatModel):
def _create_completion(
self,
@@ -282,26 +316,8 @@ class LlamaChat(Llama):
usage = ModelUsage(**response["usage"])
return response_text, usage
def build_model_input(
self,
prompt: str,
system_message: str | None,
messages: list[dict[str, str]],
history: list[dict[str, str]] | None = None,
):
# a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
# TODO is it better to check for a system message in the history?
if history is not None:
messages = history + messages
[messages.insert(index, entry) for index, entry in enumerate(history)]
elif system_message:
messages = [self._get_system_message(system_message)] + messages
return {"messages": messages}, messages
class OpenAiChat(LLMModel):
class OpenAiChat(LLMModel, ChatModel):
"""Queries an OpenAI model using its API."""
def __init__(
@@ -330,24 +346,6 @@ class OpenAiChat(LLMModel):
usage = ModelUsage(**response.usage.__dict__)
return response_text, usage
def build_model_input(
self,
prompt: str,
system_message: str | None,
messages: list[dict[str, str]],
history: list[dict[str, str]] | None = None,
):
# a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
# TODO is it better to check for a system message in the history?
if history is not None:
messages = history + messages
[messages.insert(index, entry) for index, entry in enumerate(history)]
elif system_message:
messages = [self._get_system_message(system_message)] + messages
return {"messages": messages}, messages
@classmethod
def register_arguments(cls, parser: ArgumentParser):
group = parser.add_argument_group("OpenAI model arguments")
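Taken together, a rough, illustrative sketch of the class layout after this commit (bodies elided); Python resolves create_completion from the first base in the MRO that defines it:

from abc import ABC

class LLMModel(ABC):
    def create_completion(self, *args, **kwargs): ...   # interface stub

class ChatModel:
    def create_completion(self, *args, **kwargs): ...   # chat-style message assembly

class Llama(LLMModel):
    def create_completion(self, *args, **kwargs): ...   # completion-style prompt assembly

class LlamaChat(Llama, ChatModel):
    pass

class OpenAiChat(LLMModel, ChatModel):
    pass

print([c.__name__ for c in LlamaChat.__mro__])
# ['LlamaChat', 'Llama', 'LLMModel', 'ABC', 'ChatModel', 'object']
print([c.__name__ for c in OpenAiChat.__mro__])
# ['OpenAiChat', 'LLMModel', 'ABC', 'ChatModel', 'object']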