From 33f58a73a664111ac4e60ce53e5cba9ee85c25d1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Mon, 19 Aug 2024 14:15:54 +0200
Subject: [PATCH] fix build_model_input for OpenAI models to use same layout as
 the LLama models

---
 evoprompt/models.py | 27 +++++++++++++++++++++------
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 204b91f..563979c 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -337,12 +337,9 @@ class OpenAI(LLMModel):
         messages: list[dict[str, str]],
         history: list[dict[str, str]] | None = None,
     ):
-        return {
-            "prompt": prompt,
-            "system_message": system_message,
-            "messages": messages,
-            "history": history,
-        }, messages
+        if system_message is not None:
+            prompt = system_message + prompt
+        return {"prompt": prompt}, messages
 
     @classmethod
     def register_arguments(cls, parser: ArgumentParser):
@@ -367,6 +364,24 @@ class OpenAiChat(OpenAI):
         usage = ModelUsage(**response.usage.__dict__)
         return response_text, usage
 
+    def build_model_input(
+        self,
+        prompt: str,
+        system_message: str | None,
+        messages: list[dict[str, str]],
+        history: list[dict[str, str]] | None = None,
+    ):
+
+        # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
+        # TODO is it better to check for a system message in the history?
+        if history is not None:
+            messages = history + messages
+            [messages.insert(index, entry) for index, entry in enumerate(history)]
+        elif system_message:
+            messages = [self._get_system_message(system_message)] + messages
+
+        return {"messages": messages}, messages
+
 
 argument_group = argument_parser.add_argument_group("Model arguments")
 argument_group.add_argument(
-- 
GitLab