From a78c704c8d7fd8c7cc31db1facf0ec1ecc9c2e46 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Mon, 19 Aug 2024 18:46:31 +0200
Subject: [PATCH] refactor llm model abstraction

---
 evoprompt/models.py | 144 ++++++++++++++++++++++----------------------
 1 file changed, 71 insertions(+), 73 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index ad93569..19eb363 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -20,6 +20,8 @@ logger = logging.getLogger(__name__)
 logging.captureWarnings(True)
 warnings.simplefilter("once")
 
+ChatMessages = list[dict[str, str]]
+
 
 class LLMModel(ABC):
     models: ClassVar[dict[str, type["LLMModel"]]] = {}
@@ -48,22 +50,13 @@ class LLMModel(ABC):
         if not options.disable_cache:
             cache = Cache(Path(".cache_dir", self.model_cache_key))
 
-            @cache.memoize(typed=True, ignore=["func"])
+            @cache.memoize(typed=True, ignore=[0, "func"])
             def _call_function(func, *args, **kwargs):
                 return func(*args, **kwargs)
 
             self._call_model_cached = _call_function
 
     @abstractmethod
-    def build_model_input(
-        self,
-        prompt: str,
-        system_message: str | None,
-        messages: list[dict[str, str]],
-        history: list[dict[str, str]] | None = None,
-    ):
-        pass
-
     def create_completion(
         self,
         system_message: str | None,
@@ -74,26 +67,9 @@ class LLMModel(ABC):
         prompt_prefix: str = "",
         prompt_suffix: str = "",
         stop: str = None,
-        history: list[dict[str, str]] | None = None,
+        history: ChatMessages | None = None,
         **kwargs: Any,
-    ) -> tuple[str, ModelUsage]:
-        # create prompt
-        prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
-        messages = [self._get_user_message(prompt)]
-        model_input, messages = self.build_model_input(
-            prompt, system_message, messages, history
-        )
-
-        reponse, usage = self._create_completion(
-            **model_input,
-            stop=stop,
-            use_cache=use_cache,
-            max_tokens=self.options.max_tokens,
-            **kwargs,
-        )
-
-        messages.append(self._get_assistant_message(reponse))
-        return reponse, messages, usage
+    ) -> tuple[str, ChatMessages, ModelUsage]: ...
 
     def _get_user_message(self, content: str):
         return {
@@ -200,16 +176,35 @@ class Llama(LLMModel):
         # needs to be called after model is initialized
         super().__init__(options=options, n_ctx=n_ctx, **kwargs)
 
-    def build_model_input(
+    def create_completion(
         self,
-        prompt: str,
         system_message: str | None,
-        messages: list[dict[str, str]],
-        history: list[dict[str, str]] | None = None,
-    ):
+        prompt: str,
+        *,
+        use_cache: bool = False,
+        prompt_appendix: str = "",
+        prompt_prefix: str = "",
+        prompt_suffix: str = "",
+        stop: str = None,
+        history: ChatMessages | None = None,
+        **kwargs: Any,
+    ) -> tuple[str, ChatMessages, ModelUsage]:
+        # create prompt
+        prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
+        messages = [self._get_user_message(prompt)]
         if system_message is not None:
             prompt = system_message + prompt
-        return {"prompt": prompt}, messages
+
+        response, usage = self._create_completion(
+            prompt=prompt,
+            stop=stop,
+            use_cache=use_cache,
+            max_tokens=self.options.max_tokens,
+            **kwargs,
+        )
+
+        messages.append(self._get_assistant_message(response))
+        return response, messages, usage
 
     def _create_completion(
         self,
@@ -265,7 +260,46 @@ class Llama(LLMModel):
         )
 
 
-class LlamaChat(Llama):
+class ChatModel:
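+    """Mixin that builds chat-style message lists for chat completion APIs.
+
+    Assumes the host class provides `_create_completion` as well as the
+    message helper methods and `options` from `LLMModel`.
+    """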
+
+    def create_completion(
+        self,
+        system_message: str | None,
+        prompt: str,
+        *,
+        use_cache: bool = False,
+        prompt_appendix: str = "",
+        prompt_prefix: str = "",
+        prompt_suffix: str = "",
+        stop: str = None,
+        history: ChatMessages | None = None,
+        **kwargs: Any,
+    ) -> tuple[str, ChatMessages, ModelUsage]:
+        # create prompt
+        prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
+        messages = [self._get_user_message(prompt)]
+
+        # if a history is given, we prepend it and assume that it already contains
+        # a system message, i.e., we never add a separate system message in this case
+        # TODO is it better to check for a system message in the history?
+        if history is not None:
+            messages = history + messages
+        elif system_message:
+            messages = [self._get_system_message(system_message)] + messages
+
+        response, usage = self._create_completion(
+            messages=messages,
+            stop=stop,
+            use_cache=use_cache,
+            max_tokens=self.options.max_tokens,
+            **kwargs,
+        )
+
+        messages.append(self._get_assistant_message(response))
+        return response, messages, usage
+
+
+class LlamaChat(ChatModel, Llama):
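+    # NOTE: ChatModel comes first in the MRO so that its chat-based
+    # create_completion takes precedence over Llama's completion-style one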
 
     def _create_completion(
         self,
@@ -282,26 +316,8 @@ class LlamaChat(Llama):
         usage = ModelUsage(**response["usage"])
         return response_text, usage
 
-    def build_model_input(
-        self,
-        prompt: str,
-        system_message: str | None,
-        messages: list[dict[str, str]],
-        history: list[dict[str, str]] | None = None,
-    ):
-
-        # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
-        # TODO is it better to check for a system message in the history?
-        if history is not None:
-            messages = history + messages
-            [messages.insert(index, entry) for index, entry in enumerate(history)]
-        elif system_message:
-            messages = [self._get_system_message(system_message)] + messages
 
-        return {"messages": messages}, messages
-
-
-class OpenAiChat(LLMModel):
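+# NOTE: ChatModel comes first so that its concrete create_completion satisfies
+# the abstract method declared on LLMModel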
+class OpenAiChat(ChatModel, LLMModel):
     """Queries an OpenAI model using its API."""
 
     def __init__(
@@ -330,24 +346,6 @@ class OpenAiChat(LLMModel):
         usage = ModelUsage(**response.usage.__dict__)
         return response_text, usage
 
-    def build_model_input(
-        self,
-        prompt: str,
-        system_message: str | None,
-        messages: list[dict[str, str]],
-        history: list[dict[str, str]] | None = None,
-    ):
-
-        # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
-        # TODO is it better to check for a system message in the history?
-        if history is not None:
-            messages = history + messages
-            [messages.insert(index, entry) for index, entry in enumerate(history)]
-        elif system_message:
-            messages = [self._get_system_message(system_message)] + messages
-
-        return {"messages": messages}, messages
-
     @classmethod
     def register_arguments(cls, parser: ArgumentParser):
         group = parser.add_argument_group("OpenAI model arguments")
-- 
GitLab