From 2df18f7fd25dc3a31e7479ceb1febd4c899f498c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Fri, 16 Aug 2024 17:39:53 +0200
Subject: [PATCH] remove is_chat argument

---
 evoprompt/models.py | 158 ++++++++++++++++++++++++++------------------
 main.py             |  14 +++-
 2 files changed, 106 insertions(+), 66 deletions(-)

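Notes:

The boolean chat flag is replaced by dedicated chat subclasses
(LlamaChat, OpenAIChat) that override build_model_input and
_create_completion. Assuming LLMModel.__init_subclass__ registers each
subclass under its lowercased class name (as the
options.evolution_engine + "chat" lookup in main.py implies), model
selection now looks roughly like this (illustrative sketch, not part of
the diff):

    # sketch: resolve the registered model class by its lowercased name
    model_name = "llamachat" if options.chat else "llama"
    model = LLMModel.get_model(model_name, options=options)
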
diff --git a/evoprompt/models.py b/evoprompt/models.py
index 3131fff..c646803 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -1,11 +1,11 @@
 import functools
 import inspect
 import logging
+import warnings
 from abc import ABC, abstractmethod
 from argparse import ArgumentParser, Namespace
 from pathlib import Path
 from typing import Any, Callable, ClassVar
-import warnings
 
 import llama_cpp
 import openai
@@ -22,7 +22,6 @@ warnings.simplefilter("once")
 
 class LLMModel(ABC):
     models: ClassVar[dict[str, type["LLMModel"]]] = {}
-    chat: bool
 
     def __init_subclass__(cls) -> None:
         if inspect.isabstract(cls):
@@ -43,7 +42,6 @@ class LLMModel(ABC):
 
     def __init__(self, options: Namespace, **kwargs):
         self.usage = ModelUsage()
-        self.chat = options.chat
 
         # store kwargs for caching
         self.options = options
@@ -56,6 +54,16 @@ class LLMModel(ABC):
                 self._call_model_cached
             )
 
+    @abstractmethod
+    def build_model_input(
+        self,
+        prompt: str,
+        system_message: str | None,
+        messages: list[dict[str, str]],
+        history: list[dict[str, str]] | None = None,
+    ):
+        pass
+
     def create_completion(
         self,
         system_message: str | None,
@@ -65,42 +73,18 @@ class LLMModel(ABC):
         prompt_appendix: str = "",
         prompt_prefix: str = "",
         prompt_suffix: str = "",
-        chat: bool | None = None,
         stop: str = None,
-        history: dict = None,
+        history: list[dict[str, str]] | None = None,
         **kwargs: Any,
-    ) -> tuple[str, ModelUsage]:
-        if chat is None:
-            chat = self.chat
-        max_tokens = kwargs.pop("max_tokens", self.options.max_tokens)
-
+    ) -> tuple[str, list[dict[str, str]], ModelUsage]:
         # create prompt
         prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
-
-        if not chat and system_message:
-            prompt = system_message + prompt
-
         messages = [self._get_user_message(prompt)]
-
-        if chat:
-            # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
-            # TODO is it better to check for a system message in the history?
-            if history is not None:
-                messages = history + messages
-            elif system_message:
-                messages.insert(
-                    0,
-                    self._get_system_message(system_message),
-                )
-            model_input = {"messages": messages}
-        else:
-            model_input = {"prompt": prompt}
+        model_input = self.build_model_input(prompt, system_message, messages, history)
 
         reponse, usage = self._create_completion(
-            chat=chat,
             **model_input,
             stop=stop,
-            max_tokens=max_tokens,
             use_cache=use_cache,
             **kwargs,
         )
@@ -137,7 +121,7 @@ class LLMModel(ABC):
         if use_cache:
             # use cached function call
             cache_key = self._compute_cache_key(
-                model_completion_fn.__name__, **self.options.__dict__, **self.kwargs
+                self.__class__.__name__, **self.options.__dict__, **self.kwargs
             )
             return self._call_model_cached(model_completion_fn, cache_key, **kwargs)
         else:
@@ -205,26 +189,29 @@ class Llama(LLMModel):
         # needs to be called after model is initialized
         super().__init__(options=options, n_ctx=n_ctx, **kwargs)
 
+    def build_model_input(
+        self,
+        prompt: str,
+        system_message: str | None,
+        messages: list[dict[str, str]],
+        history: list[dict[str, str]] | None = None,
+    ):
+
+        if system_message is not None:
+            prompt = system_message + prompt
+        return {"prompt": prompt}
+
     def _create_completion(
         self,
-        chat: bool,
         use_cache: bool = False,
         **kwargs,
     ):
-        if chat:
-            response = self._call_model(
-                self.model.create_chat_completion,
-                use_cache=use_cache,
-                **kwargs,
-            )
-            response_text = response["choices"][0]["message"]["content"]
-        else:
-            response = self._call_model(
-                self.model.create_completion,
-                use_cache=use_cache,
-                **kwargs,
-            )
-            response_text = response["choices"][0]["text"]
+        response = self._call_model(
+            self.model.create_completion,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        response_text = response["choices"][0]["text"]
 
         usage = ModelUsage(**response["usage"])
         return response_text, usage
@@ -272,6 +259,43 @@ class Llama(LLMModel):
         )
 
 
+class LlamaChat(Llama):
+
+    def _create_completion(
+        self,
+        use_cache: bool = False,
+        **kwargs,
+    ):
+        response = self._call_model(
+            self.model.create_chat_completion,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        response_text = response["choices"][0]["message"]["content"]
+
+        usage = ModelUsage(**response["usage"])
+        return response_text, usage
+
+    def build_model_input(
+        self,
+        prompt: str,
+        system_message: str | None,
+        messages: list[dict[str, str]],
+        history: list[dict[str, str]] | None = None,
+    ):
+
+        # if a history is given, it is prepended to the messages; we assume it already includes a system message, so we never add one ourselves in that case
+        # TODO is it better to check for a system message in the history?
+        if history is not None:
+            messages = history + messages
+        elif system_message:
+            messages.insert(
+                0,
+                self._get_system_message(system_message),
+            )
+        return {"messages": messages}
+
+
 class OpenAI(LLMModel):
     """Queries an OpenAI model using its API."""
 
@@ -288,26 +312,16 @@ class OpenAI(LLMModel):
 
     def _create_completion(
         self,
-        chat: bool,
         use_cache: bool = False,
         **kwargs,
     ):
-        if chat:
-            response = self._call_model(
-                self.openai_client.chat.completions.create,
-                model=self.model_name,
-                use_cache=use_cache,
-                **kwargs,
-            )
-            response_text = response.choices[0].message.content
-        else:
-            response = self._call_model(
-                self.openai_client.completions.create,
-                model=self.model,
-                use_cache=use_cache,
-                **kwargs,
-            )
-            response_text = response.choices[0].text
+        response = self._call_model(
+            self.openai_client.completions.create,
+            model=self.model,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        response_text = response.choices[0].text
 
         usage = ModelUsage(**response.usage.__dict__)
         return response_text, usage
@@ -322,6 +336,24 @@ class OpenAI(LLMModel):
         group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
 
 
+class OpenAIChat(OpenAI):
+
+    def _create_completion(
+        self,
+        use_cache: bool = False,
+        **kwargs,
+    ):
+        response = self._call_model(
+            self.openai_client.chat.completions.create,
+            model=self.model_name,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        response_text = response.choices[0].message.content
+        usage = ModelUsage(**response.usage.__dict__)
+        return response_text, usage
+
+
 argument_group = argument_parser.add_argument_group("Model arguments")
 argument_group.add_argument(
     "--evolution-engine",
diff --git a/main.py b/main.py
index 3ed6132..90bac4b 100644
--- a/main.py
+++ b/main.py
@@ -6,7 +6,7 @@ from dotenv import load_dotenv
 
 from evoprompt.cli import argument_parser
 from evoprompt.evolution import get_optimizer_class
-from evoprompt.models import Llama, LLMModel
+from evoprompt.models import Llama, LlamaChat, LLMModel
 from evoprompt.task import get_task
 from evoprompt.utils import init_rng, setup_console_logger
 
@@ -62,7 +62,12 @@ if __name__ == "__main__":
         logger.info("DEBUG mode: Do a quick run")
 
     # set up evolution model
-    evolution_model = LLMModel.get_model(options.evolution_engine, options=options)
+    evolution_model_name = (
+        (options.evolution_engine + "chat")
+        if options.chat
+        else options.evolution_engine
+    )
+    evolution_model = LLMModel.get_model(evolution_model_name, options=options)
 
     match options.evolution_engine:
         case "llama":
@@ -79,7 +84,10 @@ if __name__ == "__main__":
         case "llama":
             evaluation_model = evolution_model
         case "openai":
-            evaluation_model = Llama(options)
+            if options.chat:
+                evaluation_model = LlamaChat(options)
+            else:
+                evaluation_model = Llama(options)
 
     task = get_task(options.task, evaluation_model, **options.__dict__)
     logger.info(f"Running with task {task.__class__.__name__}")
-- 
GitLab