From b47cdc459f6f64b338324121281032311f55678e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Mon, 19 Aug 2024 11:05:29 +0200
Subject: [PATCH] fix error for OpenAI model after refactoring

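After the refactoring, OpenAiChat was accidentally declared with "def"
instead of "class", and OpenAI was missing its build_model_input hook.
Chat variants are now selected through the engine name itself
("openaichat", "llamachat"), so the --chat flag and the chat/no-chat
branching in LLMModel and main.py can be dropped.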
---
 evoprompt/models.py |  9 ++++-----
 main.py             | 18 ++++++------------
 2 files changed, 10 insertions(+), 17 deletions(-)

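Note on build_model_input: it looks like the per-model hook for shaping
completion inputs, and the OpenAI chat API already takes plain keyword
arguments, so a pass-through is enough there. For contrast, a
hypothetical non-chat model (names assumed, not from this repository)
might have to collapse everything into a single prompt string:

    def build_model_input(self, *, prompt: str, system_message: str | None = None, **kwargs):
        # hypothetical: plain-completion APIs expect one prompt string
        text = (system_message + "\n" if system_message else "") + prompt
        return {"prompt": text, **kwargs}
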
diff --git a/evoprompt/models.py b/evoprompt/models.py
index 1349359..2d65147 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -73,9 +73,6 @@ class LLMModel(ABC):
         history: list[dict[str, str]] | None = None,
         **kwargs: Any,
     ) -> tuple[str, ModelUsage]:
-        if chat is None:
-            chat = self.chat
-
         # create prompt
         prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
         messages = [self._get_user_message(prompt)]
@@ -331,13 +328,16 @@ class OpenAI(LLMModel):
         usage = ModelUsage(**response.usage.__dict__)
         return response_text, usage
 
+    def build_model_input(self, **kwargs):
+        return kwargs
+
     @classmethod
     def register_arguments(cls, parser: ArgumentParser):
         group = parser.add_argument_group("OpenAI model arguments")
         group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
 
 
-def OpenAiChat(OpenAI):
+class OpenAiChat(OpenAI):
 
     def _create_completion(
         self,
@@ -372,4 +372,3 @@ argument_group.add_argument(
     type=int,
     help="Maximum number of tokens being generated from LLM. ",
 )
-argument_group.add_argument("--chat", "-c", action="store_true")
diff --git a/main.py b/main.py
index 90bac4b..e0b8fae 100644
--- a/main.py
+++ b/main.py
@@ -61,13 +61,8 @@ if __name__ == "__main__":
     if debug:
         logger.info("DEBUG mode: Do a quick run")
 
-    # set up evolution model
-    evolution_model_name = (
-        (options.evolution_engine + "chat")
-        if options.chat
-        else options.evolution_engine
-    )
-    evolution_model = LLMModel.get_model(evolution_model_name, options=options)
+    # set up evolution model
+    evolution_model = LLMModel.get_model(options.evolution_engine, options=options)
 
     match options.evolution_engine:
         case "llama":
@@ -81,13 +76,12 @@ if __name__ == "__main__":
     logger.info("Using Llama as the evaluation engine")
     evaluation_model: LLMModel
     match options.evolution_engine:
-        case "llama":
+        case "llama" | "llamachat":
             evaluation_model = evolution_model
         case "openai":
-            if not options.chat:
-                evaluation_model = Llama(options)
-            else:
-                evaluation_model = LlamaChat(options)
+            evaluation_model = Llama(options)
+        case "openaichat":
+            evaluation_model = LlamaChat(options)
 
     task = get_task(options.task, evaluation_model, **options.__dict__)
     logger.info(f"Running with task {task.__class__.__name__}")
-- 
GitLab
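
A minimal sketch of how engine-name based model lookup could work now
that the --chat flag is gone. LLMModel.get_model is called in main.py
above, but its implementation is not part of this diff, so the registry
below is an assumption for illustration, not the project's actual code:

    from abc import ABC
    from typing import Any


    class LLMModel(ABC):
        # hypothetical name -> class registry; the real get_model may differ
        registry: dict[str, type["LLMModel"]] = {}

        def __init_subclass__(cls, **kwargs: Any) -> None:
            super().__init_subclass__(**kwargs)
            # register each subclass under its lowercased class name,
            # e.g. OpenAiChat -> "openaichat", LlamaChat -> "llamachat"
            LLMModel.registry[cls.__name__.lower()] = cls

        @classmethod
        def get_model(cls, name: str, **kwargs: Any) -> "LLMModel":
            # a chat variant is just another registered name, so no
            # separate --chat flag is needed to select it
            return cls.registry[name](**kwargs)


    class OpenAI(LLMModel): ...
    class OpenAiChat(OpenAI): ...
    class Llama(LLMModel): ...
    class LlamaChat(Llama): ...


    model = LLMModel.get_model("openaichat")
    assert isinstance(model, OpenAiChat)

Registering subclasses via __init_subclass__ keeps the lookup table in
sync with the class hierarchy, which is what lets "openaichat" and
"llamachat" replace the boolean flag.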