From a78c704c8d7fd8c7cc31db1facf0ec1ecc9c2e46 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Mon, 19 Aug 2024 18:46:31 +0200
Subject: [PATCH] refactor llm model abstraction

---
 evoprompt/models.py | 143 ++++++++++++++++++++++++++++----------------------------
 1 file changed, 70 insertions(+), 73 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index ad93569..19eb363 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -20,6 +20,8 @@ logger = logging.getLogger(__name__)
 logging.captureWarnings(True)
 warnings.simplefilter("once")
 
+ChatMessages = list[dict[str, str]]
+
 
 class LLMModel(ABC):
     models: ClassVar[dict[str, type["LLMModel"]]] = {}
@@ -48,22 +50,13 @@ class LLMModel(ABC):
         if not options.disable_cache:
             cache = Cache(Path(".cache_dir", self.model_cache_key))
 
-            @cache.memoize(typed=True, ignore=["func"])
+            @cache.memoize(typed=True, ignore=[0, "func"])
             def _call_function(func, *args, **kwargs):
                 return func(*args, **kwargs)
 
             self._call_model_cached = _call_function
 
     @abstractmethod
-    def build_model_input(
-        self,
-        prompt: str,
-        system_message: str | None,
-        messages: list[dict[str, str]],
-        history: list[dict[str, str]] | None = None,
-    ):
-        pass
-
     def create_completion(
         self,
         system_message: str | None,
@@ -74,26 +67,9 @@ class LLMModel(ABC):
         prompt_prefix: str = "",
         prompt_suffix: str = "",
         stop: str = None,
-        history: list[dict[str, str]] | None = None,
+        history: ChatMessages | None = None,
         **kwargs: Any,
-    ) -> tuple[str, ModelUsage]:
-        # create prompt
-        prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
-        messages = [self._get_user_message(prompt)]
-        model_input, messages = self.build_model_input(
-            prompt, system_message, messages, history
-        )
-
-        reponse, usage = self._create_completion(
-            **model_input,
-            stop=stop,
-            use_cache=use_cache,
-            max_tokens=self.options.max_tokens,
-            **kwargs,
-        )
-
-        messages.append(self._get_assistant_message(reponse))
-        return reponse, messages, usage
+    ) -> tuple[str, ChatMessages, ModelUsage]: ...
 
     def _get_user_message(self, content: str):
         return {
@@ -200,16 +176,35 @@ class Llama(LLMModel):
         # needs to be called after model is initialized
         super().__init__(options=options, n_ctx=n_ctx, **kwargs)
 
-    def build_model_input(
+    def create_completion(
         self,
-        prompt: str,
         system_message: str | None,
-        messages: list[dict[str, str]],
-        history: list[dict[str, str]] | None = None,
-    ):
+        prompt: str,
+        *,
+        use_cache: bool = False,
+        prompt_appendix: str = "",
+        prompt_prefix: str = "",
+        prompt_suffix: str = "",
+        stop: str = None,
+        history: ChatMessages | None = None,
+        **kwargs: Any,
+    ) -> tuple[str, ChatMessages, ModelUsage]:
+        # create prompt
+        prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
+        messages = [self._get_user_message(prompt)]
         if system_message is not None:
             prompt = system_message + prompt
-        return {"prompt": prompt}, messages
+
+        response, usage = self._create_completion(
+            prompt=prompt,
+            stop=stop,
+            use_cache=use_cache,
+            max_tokens=self.options.max_tokens,
+            **kwargs,
+        )
+
+        messages.append(self._get_assistant_message(response))
+        return response, messages, usage
 
     def _create_completion(
         self,
@@ -265,7 +260,45 @@ class Llama(LLMModel):
         )
 
 
-class LlamaChat(Llama):
+class ChatModel:
+
+    def create_completion(
+        self,
+        system_message: str | None,
+        prompt: str,
+        *,
+        use_cache: bool = False,
+        prompt_appendix: str = "",
+        prompt_prefix: str = "",
+        prompt_suffix: str = "",
+        stop: str = None,
+        history: ChatMessages | None = None,
+        **kwargs: Any,
+    ) -> tuple[str, ChatMessages, ModelUsage]:
+        # create prompt
+        prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
+        messages = [self._get_user_message(prompt)]
+
+        # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
+        # TODO is it better to check for a system message in the history?
+        if history is not None:
+            messages = history + messages
+        elif system_message:
+            messages = [self._get_system_message(system_message)] + messages
+
+        response, usage = self._create_completion(
+            messages=messages,
+            stop=stop,
+            use_cache=use_cache,
+            max_tokens=self.options.max_tokens,
+            **kwargs,
+        )
+
+        messages.append(self._get_assistant_message(response))
+        return response, messages, usage
+
+
+class LlamaChat(ChatModel, Llama):
 
     def _create_completion(
         self,
@@ -282,26 +315,8 @@ class LlamaChat(Llama):
         usage = ModelUsage(**response["usage"])
         return response_text, usage
 
-    def build_model_input(
-        self,
-        prompt: str,
-        system_message: str | None,
-        messages: list[dict[str, str]],
-        history: list[dict[str, str]] | None = None,
-    ):
-
-        # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
-        # TODO is it better to check for a system message in the history?
-        if history is not None:
-            messages = history + messages
-            [messages.insert(index, entry) for index, entry in enumerate(history)]
-        elif system_message:
-            messages = [self._get_system_message(system_message)] + messages
-        return {"messages": messages}, messages
-
-
-class OpenAiChat(LLMModel):
+class OpenAiChat(ChatModel, LLMModel):
     """Queries an OpenAI model using its API."""
 
     def __init__(
         self,
@@ -330,24 +345,6 @@ class OpenAiChat(LLMModel):
         usage = ModelUsage(**response.usage.__dict__)
         return response_text, usage
 
-    def build_model_input(
-        self,
-        prompt: str,
-        system_message: str | None,
-        messages: list[dict[str, str]],
-        history: list[dict[str, str]] | None = None,
-    ):
-
-        # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
-        # TODO is it better to check for a system message in the history?
-        if history is not None:
-            messages = history + messages
-            [messages.insert(index, entry) for index, entry in enumerate(history)]
-        elif system_message:
-            messages = [self._get_system_message(system_message)] + messages
-
-        return {"messages": messages}, messages
-
     @classmethod
     def register_arguments(cls, parser: ArgumentParser):
         group = parser.add_argument_group("OpenAI model arguments")
-- 
GitLab
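
After this refactor, create_completion returns the response text, the running message list, and the token usage, so a caller can thread the returned messages back in as history. A minimal sketch of that flow, assuming a constructed LLMModel subclass instance; the name `model` and the prompts are illustrative, not part of the patch:

    # Hedged sketch: `model` stands for any constructed LLMModel subclass
    # (e.g. LlamaChat or OpenAiChat); prompts are invented for illustration.
    response, messages, usage = model.create_completion(
        system_message="You are a helpful assistant.",
        prompt="Summarize the plot of Hamlet.",
    )

    # `messages` now ends with the assistant reply; passing it back as
    # `history` makes ChatModel prepend it and skip adding a second
    # system message.
    response2, messages, usage2 = model.create_completion(
        system_message=None,
        prompt="Now in a single sentence.",
        history=messages,
    )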
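The base-class order on LlamaChat and OpenAiChat matters: Python resolves create_completion left to right along the MRO, so the ChatModel mixin must precede the model class for its chat-style implementation to win over Llama's prompt-style one (or over LLMModel's abstract stub). A self-contained toy sketch of the pattern, with names invented for illustration:

    # Toy classes, not the evoprompt API: the mixin listed first in the
    # bases shadows the base class implementation via the MRO.
    class Base:
        def create_completion(self) -> str:
            return "plain completion prompt"

    class ChatMixin:
        def create_completion(self) -> str:
            return "chat-style message list"

    class Chat(ChatMixin, Base):  # mixin first, as in LlamaChat above
        pass

    assert Chat().create_completion() == "chat-style message list"
    print([c.__name__ for c in Chat.__mro__])  # ['Chat', 'ChatMixin', 'Base', 'object']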