From 543d62471e8793185ce062071bdbfdf6139ed835 Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Wed, 21 Aug 2024 11:44:24 +0200
Subject: [PATCH] Fix chat model not taking history into account

---
 evoprompt/evolution.py |  4 ++--
 evoprompt/models.py    | 16 ++++++++++------
 2 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/evoprompt/evolution.py b/evoprompt/evolution.py
index ae909f2..49050b5 100644
--- a/evoprompt/evolution.py
+++ b/evoprompt/evolution.py
@@ -314,10 +314,10 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
         )
 
         messages = None
-        for idx, prompt in enumerate(DE_COT_PROMPTS):
+        for idx, prompt_template in enumerate(DE_COT_PROMPTS):
             response, messages, usage = self.evolution_model.create_completion(
                 system_message=SYSTEM_MESSAGE,
-                prompt=prompt.format(
+                prompt=prompt_template.format(
                     prompt1=prompt_1,
                     prompt2=prompt_2,
                     prompt3=best_prompt_current_evolution,
diff --git a/evoprompt/models.py b/evoprompt/models.py
index f922169..21a3e15 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -275,14 +275,18 @@ class ChatModel:
         history: ChatMessages | None = None,
         **kwargs: Any,
     ) -> tuple[str, ModelUsage]:
-        # create prompt
-        prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
-        messages = [self._get_user_message(prompt)]
-
         # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
         # TODO is it better to check for a system message in the history?
         if history is None and system_message:
-            history = [self._get_system_message(system_message)]
+            messages = [self._get_system_message(system_message)]
+        elif history is not None:
+            messages = history
+        else:
+            messages = []
+
+        # create prompt
+        prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
+        messages += [self._get_user_message(prompt)]
 
         reponse, usage = self._create_completion(
             messages=messages,
@@ -293,7 +297,7 @@ class ChatModel:
         )
 
         messages.append(self._get_assistant_message(reponse))
-        return reponse, history + messages, usage
+        return reponse, messages, usage
 
 
 class LlamaChat(ChatModel, Llama):
-- 
GitLab