From f57cbe077d97f825da356ccbf3ce05e6c0575fce Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Thu, 10 Oct 2024 16:12:30 +0200
Subject: [PATCH] Use chat format for evolution demonstration samples
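
Pass demonstration examples for the evolution step to the evolution model
as chat-format messages (via the `history` argument of `create_completion`)
instead of concatenating them into the prompt template. Factor prompt
extraction from model responses into a shared `parse_response` helper and
remove the now unused `evoprompt/evolution/template.py`.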

---
 evoprompt/evolution/evolution.py | 259 +++++++++++++++++++++----------
 evoprompt/evolution/template.py  |   6 -
 2 files changed, 173 insertions(+), 92 deletions(-)
 delete mode 100644 evoprompt/evolution/template.py

diff --git a/evoprompt/evolution/evolution.py b/evoprompt/evolution/evolution.py
index ce57de5..b5d71f9 100644
--- a/evoprompt/evolution/evolution.py
+++ b/evoprompt/evolution/evolution.py
@@ -1,3 +1,4 @@
+from collections.abc import Iterable
 import logging
 import re
 from abc import ABCMeta, abstractmethod
@@ -7,7 +8,6 @@ from tqdm import trange
 import wandb
 import weave
 
-from evoprompt.evolution.template import get_demonstration_prompt_template
 from evoprompt.evolution.template_de import (
     DE_DEMONSTRATION_DATA_CLS,
     DE_DEMONSTRATION_DATA_SIM,
@@ -91,6 +91,42 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
     ) -> tuple[str, list[Judgement], ModelUsage]:
         pass
 
+    def build_demonstration_prompt(
+        self,
+        demonstration_samples: Iterable[tuple[str, str]],
+        instruction: str | None = None,
+    ) -> ChatMessages:
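+        # encode the demonstration samples as chat messages (user/assistant turns) via the evolution model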
+        return self.evolution_model.build_demonstration_data(
+            demonstration_samples,
+            instruction=instruction,
+        )
+
+    def parse_response(
+        self,
+        response: str,
+        start_token: str = "<prompt>",
+        end_token: str = "</prompt>",
+        allow_missing_end_token: bool = True,
+    ) -> str | None:
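+        """Extract the text between the last pair of `start_token` and `end_token` in `response`; return None if no prompt can be extracted."""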
+        matches = re.findall(
+            # match any characters between the last pair of `start_token` and `end_token`; optionally tolerate a missing `end_token`
+            rf"{start_token}(?!.*{start_token})(?:(.*){end_token}{'|(.*)' if allow_missing_end_token else ''})",
+            response,
+            flags=(re.IGNORECASE | re.DOTALL),
+        )
+        if matches and any(matches[0]):
+            # there is always only a single match, and one group should be non-empty
+            if matches[0][0]:
+                evolved_prompt = matches[0][0]
+            else:
+                assert matches[0][1]
+                evolved_prompt = matches[0][1]
+
+        else:
+            # could not extract prompt -> no new prompt
+            evolved_prompt = None
+        return evolved_prompt
+
     @abstractmethod
     def update(self, *args, **kwargs):
         pass
@@ -245,20 +281,23 @@ class GeneticAlgorithm(EvolutionAlgorithm):
         # Based on this two-step process, we design instructions, guiding LLMs to
         # generate a new prompt based on these steps to perform Evo(·) in Algorithm 1.
 
-        filled_prompt = self.get_prompt_template().format(
+        prompt = self._get_prompt_template()
+        demo_messages = None
+        if isinstance(prompt, tuple):
+            # the template comes bundled with chat-format demonstrations
+            prompt, demo_messages = prompt
+        filled_prompt = prompt.format(
             prompt1=prompt_1,
             prompt2=prompt_2,
         )
-        evolved_prompt, history, recent_turn, usage = (
-            self.evolution_model.create_completion(
-                system_message=SYSTEM_MESSAGE,
-                prompt=filled_prompt,
-                use_randomness=True,
-            )
+        response, history, recent_turn, usage = self.evolution_model.create_completion(
+            system_message=SYSTEM_MESSAGE,
+            prompt=filled_prompt,
+            history=demo_messages if self.use_evolution_demo else None,
+            use_randomness=True,
         )
 
         judgement = self.judge_and_correct_step(
-            filled_prompt, evolved_prompt, history, recent_turn
+            filled_prompt, response, history, recent_turn
         )
 
         if judgement.skip:
@@ -269,17 +308,18 @@ class GeneticAlgorithm(EvolutionAlgorithm):
                 usage,
             )
 
-        evolved_prompt = judgement.corrected_response
-
-        if "<prompt>" in evolved_prompt:
-            evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
+        response = judgement.corrected_response
+        evolved_prompt = self.parse_response(response)
 
-        logger.info(
-            "GA-evolved prompts '%s' and '%s' into '%s'",
-            prompt_1,
-            prompt_2,
-            evolved_prompt,
-        )
+        if evolved_prompt is None:
+            logger.info("Could not extract prompt from response: %s", response)
+        else:
+            logger.info(
+                "GA-evolved prompts '%s' and '%s' into '%s'",
+                prompt_1,
+                prompt_2,
+                evolved_prompt,
+            )
 
         return evolved_prompt, [judgement], usage
 
@@ -305,22 +345,28 @@ class GeneticAlgorithm(EvolutionAlgorithm):
 
         return retained_prompts
 
-    def get_prompt_template(self):
+    def _get_prompt_template(self):
         if self.use_evolution_demo:
             if isinstance(
                 self.task, (TextClassification, Summarization, QuestionAnswering)
             ):
-                return get_demonstration_prompt_template(
-                    GA_PROMPT, GA_DEMONSTRATION_DATA_SIM
-                )
+                demonstration_data = GA_DEMONSTRATION_DATA_SIM
             elif isinstance(self.task, Simplification):
-                return get_demonstration_prompt_template(
-                    GA_PROMPT, GA_DEMONSTRATION_DATA_CLS
-                )
+                demonstration_data = GA_DEMONSTRATION_DATA_CLS
             else:
                 raise NotImplementedError(
                     f"Prompt with demonstration data is not implemented for task of type {type(self.task)}."
                 )
+
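+            # wrap one filled-in example (template plus expected response) as chat-format demonstration messages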
+            prompt_with_demonstration_data = self.build_demonstration_prompt(
+                [
+                    (
+                        GA_PROMPT.format(**demonstration_data),
+                        demonstration_data["response"],
+                    )
+                ]
+            )
+            return GA_PROMPT, prompt_with_demonstration_data
         return GA_PROMPT
 
 
@@ -346,22 +392,25 @@ class DifferentialEvolution(EvolutionAlgorithm):
             prompts_current_evolution, key=lambda prompt: prompt.score
         )
 
-        filled_prompt = self.get_prompt_template().format(
+        prompt = self._get_prompt_template()
+        demo_messages = None
+        if isinstance(prompt, tuple):
+            # the template comes bundled with chat-format demonstrations
+            prompt, demo_messages = prompt
+        filled_prompt = prompt.format(
             prompt1=prompt_1,
             prompt2=prompt_2,
             prompt3=best_prompt_current_evolution,
             basic_prompt=prompts_current_evolution[current_iteration],
         )
-        evolved_prompt, history, recent_turn, usage = (
-            self.evolution_model.create_completion(
-                system_message=SYSTEM_MESSAGE,
-                prompt=filled_prompt,
-                use_randomness=True,
-            )
+        response, history, recent_turn, usage = self.evolution_model.create_completion(
+            system_message=SYSTEM_MESSAGE,
+            prompt=filled_prompt,
+            history=demo_messages if self.use_evolution_demo else None,
+            use_randomness=True,
         )
 
         judgement = self.judge_and_correct_step(
-            filled_prompt, evolved_prompt, history, recent_turn
+            filled_prompt, response, history, recent_turn
         )
 
         if judgement.skip:
@@ -372,33 +421,20 @@ class DifferentialEvolution(EvolutionAlgorithm):
                 usage,
             )
 
-        evolved_prompt = judgement.corrected_response
+        response = judgement.corrected_response
+        evolved_prompt = self.parse_response(response)
 
-        matches = re.findall(
-            # regex that matches any characters between last pair of <prompt></prompt>, also if </prompt> is missing
-            r"<prompt>(?!.*<prompt>)(?:(.*)</prompt>|(.*))",
-            evolved_prompt,
-            flags=(re.IGNORECASE | re.DOTALL),
-        )
-        if matches and any(matches[0]):
-            # there is always only a single match, and one group should be non-empty
-            if matches[0][0]:
-                evolved_prompt = matches[0][0]
-            else:
-                assert matches[0][1]
-                evolved_prompt = matches[0][1]
+        if evolved_prompt is None:
+            logger.info("Could not extract prompt from response: %s", response)
         else:
-            # TODO what to do in this case? Discard generated prompt directly?
-            pass
-
-        logger.info(
-            "DE-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
-            prompt_1,
-            prompt_2,
-            best_prompt_current_evolution,
-            prompts_current_evolution[current_iteration],
-            evolved_prompt,
-        )
+            logger.info(
+                "DE-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
+                prompt_1,
+                prompt_2,
+                best_prompt_current_evolution,
+                prompts_current_evolution[current_iteration],
+                evolved_prompt,
+            )
 
         return evolved_prompt, [judgement], usage
 
@@ -416,22 +452,28 @@ class DifferentialEvolution(EvolutionAlgorithm):
         ]
         return population
 
-    def get_prompt_template(self):
+    def _get_prompt_template(self):
         if self.use_evolution_demo:
             if isinstance(
                 self.task, (TextClassification, Summarization, QuestionAnswering)
             ):
-                return get_demonstration_prompt_template(
-                    DE_PROMPT, DE_DEMONSTRATION_DATA_SIM
-                )
+                demonstration_data = DE_DEMONSTRATION_DATA_SIM
             elif isinstance(self.task, Simplification):
-                return get_demonstration_prompt_template(
-                    DE_PROMPT, DE_DEMONSTRATION_DATA_CLS
-                )
+                demonstration_data = DE_DEMONSTRATION_DATA_CLS
             else:
                 raise NotImplementedError(
                     f"Prompt with demonstration data is not implemented for task of type {type(self.task)}."
                 )
+
+            prompt_with_demonstration_data = self.build_demonstration_prompt(
+                [
+                    (
+                        DE_PROMPT.format(**demonstration_data),
+                        demonstration_data["response"],
+                    )
+                ]
+            )
+            return DE_PROMPT, prompt_with_demonstration_data
         return DE_PROMPT
 
 
@@ -457,27 +499,46 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
             prompts_current_evolution, key=lambda prompt: prompt.score
         )
 
-        history = None
+        # chat messages accumulated across the evolution steps (user prompt and assistant response per step)
+        evolutions_steps = []
+        # list of turns, each turn being a list of demonstration messages
+        demos = [[]]
         response: str = ""
         judgements: list[Judgement] = []
         usage: ModelUsage = ModelUsage()
-        for idx, prompt in enumerate(self.get_prompt_template()):
+        for idx, prompt in enumerate(self._get_prompt_template()):
+            if isinstance(prompt, tuple):
+                # the template comes bundled with chat-format demonstrations
+                prompt, demo_messages = prompt
+                demos[-1].extend(demo_messages)
+            if self.use_evolution_demo:
+                messages_demos = self.condense_messages(demos[-1])
+            else:
+                messages_demos = []
             filled_prompt = prompt.format(
                 prompt1=prompt_1,
                 prompt2=prompt_2,
                 prompt3=best_prompt_current_evolution,
                 basic_prompt=prompts_current_evolution[current_iteration],
             )
+            evolutions_steps.append(
+                self.evolution_model._get_user_message(filled_prompt)
+            )
+            # TODO Shall we still use only a single turn containing all messages if we do not use demonstrations for evolution?
+            messages = self.condense_messages(evolutions_steps)
             response, history, recent_turn, usage = (
                 self.evolution_model.create_completion(
                     system_message=SYSTEM_MESSAGE,
-                    prompt=filled_prompt,
-                    history=history,
+                    messages=messages,
+                    history=messages_demos,
                     # the models often repeat the instruction, which could also contain </prompt>, so we should not stop early
-                    stop=None, # "</prompt>" if idx == len(DE_COT_PROMPTS) - 1 else None,
+                    stop=None,
                     use_randomness=True,
                 )
             )
+            evolutions_steps.append(
+                self.evolution_model._get_assistant_message(response)
+            )
             logger.debug(
                 "Performed evolution (step %d) using DE-CoT:\n\tInputs: %s\n\tResponse: %s",
                 idx,
@@ -504,22 +565,39 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
             # update history with recent turn
             history += recent_turn
 
-        # at this point we should get a new prompt
-        if "<prompt>" in response:
-            response = response.split("<prompt>")[1].split("</prompt>")[0]
+        evolved_prompt = self.parse_response(response)
 
-        logger.info(
-            "DE-CoT-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
-            prompt_1,
-            prompt_2,
-            best_prompt_current_evolution,
-            prompts_current_evolution[current_iteration],
-            response,
-        )
+        if evolved_prompt is None:
+            logger.info("Could not extract prompt from response: %s", response)
+        else:
+            logger.info(
+                "DE-CoT-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
+                prompt_1,
+                prompt_2,
+                best_prompt_current_evolution,
+                prompts_current_evolution[current_iteration],
+                evolved_prompt,
+            )
+
+        return evolved_prompt, judgements, usage
+
+    def condense_messages(self, messages: list[ChatMessages]) -> list[dict]:
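+        """Condense messages into a single user turn (contents joined by blank lines), keeping a trailing assistant message as a separate turn."""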
+        if not messages:
+            return []
 
-        return response, judgements, usage
+        if messages[-1]["role"] == "assistant":
+            assistant_turn = messages[-1]
+            messages = messages[:-1]
+        else:
+            assistant_turn = None
+
+        user_turn = "\n\n".join(message["content"] for message in messages)
+        messages = [self.evolution_model._get_user_message(user_turn)]
+        if assistant_turn is not None:
+            messages.append(assistant_turn)
+        return messages
 
-    def get_prompt_template(self):
+    def _get_prompt_template(self):
         if self.use_evolution_demo:
             if isinstance(
                 self.task, (TextClassification, Summarization, QuestionAnswering)
@@ -532,9 +610,18 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
                     f"Prompt with demonstration data is not implemented for task of type {type(self.task)}."
                 )
 
-            for prompt, demonstration_data_item in zip(
+            for prompt_template, demonstration_data_item in zip(
                 DE_COT_PROMPTS, demonstration_data
             ):
-                yield get_demonstration_prompt_template(prompt, demonstration_data_item)
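+                # build one chat-format demonstration (filled template plus expected response) for this CoT step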
+                prompt_with_demonstration_data = self.build_demonstration_prompt(
+                    [
+                        (
+                            prompt_template.format(**demonstration_data_item),
+                            demonstration_data_item["response"],
+                        )
+                    ]
+                )
+                yield prompt_template, prompt_with_demonstration_data
+                # TODO how can we add a generation prefix for the model?
         else:
             yield from DE_COT_PROMPTS
diff --git a/evoprompt/evolution/template.py b/evoprompt/evolution/template.py
deleted file mode 100644
index a84dbb3..0000000
--- a/evoprompt/evolution/template.py
+++ /dev/null
@@ -1,6 +0,0 @@
-def get_demonstration_prompt_template(prompt_template: str, demonstration_data: dict):
-    prompt_template_with_demo = prompt_template.format(**demonstration_data)
-    prompt_template_with_demo += "\n\n" + demonstration_data["response"]
-    prompt_template_with_demo += "\n\n" + prompt_template
-    prompt_template_with_demo += "\n\n" + demonstration_data["generation_prefix"]
-    return prompt_template_with_demo
-- 
GitLab