From 543d62471e8793185ce062071bdbfdf6139ed835 Mon Sep 17 00:00:00 2001 From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de> Date: Wed, 21 Aug 2024 11:44:24 +0200 Subject: [PATCH] Fix chat model not taking into account history --- evoprompt/evolution.py | 4 ++-- evoprompt/models.py | 16 ++++++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/evoprompt/evolution.py b/evoprompt/evolution.py index ae909f2..49050b5 100644 --- a/evoprompt/evolution.py +++ b/evoprompt/evolution.py @@ -314,10 +314,10 @@ class DifferentialEvolutionWithCot(DifferentialEvolution): ) messages = None - for idx, prompt in enumerate(DE_COT_PROMPTS): + for idx, prompt_template in enumerate(DE_COT_PROMPTS): response, messages, usage = self.evolution_model.create_completion( system_message=SYSTEM_MESSAGE, - prompt=prompt.format( + prompt=prompt_template.format( prompt1=prompt_1, prompt2=prompt_2, prompt3=best_prompt_current_evolution, diff --git a/evoprompt/models.py b/evoprompt/models.py index f922169..21a3e15 100644 --- a/evoprompt/models.py +++ b/evoprompt/models.py @@ -275,14 +275,18 @@ class ChatModel: history: ChatMessages | None = None, **kwargs: Any, ) -> tuple[str, ModelUsage]: - # create prompt - prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix - messages = [self._get_user_message(prompt)] - # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case # TODO is it better to check for a system message in the history? if history is None and system_message: - history = [self._get_system_message(system_message)] + messages = [self._get_system_message(system_message)] + elif history is not None: + messages = history + else: + messages = [] + + # create prompt + prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix + messages += [self._get_user_message(prompt)] reponse, usage = self._create_completion( messages=messages, @@ -293,7 +297,7 @@ class ChatModel: ) messages.append(self._get_assistant_message(reponse)) - return reponse, history + messages, usage + return reponse, messages, usage class LlamaChat(ChatModel, Llama): -- GitLab