From f021dfcb1444286e3c57626d1082d8f5a427fc4b Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Fri, 6 Sep 2024 18:38:54 +0200
Subject: [PATCH] Allow user to skip evolution if response was judged as bad

---
 evoprompt/evolution/evolution.py | 67 +++++++++++++++++++++++---------
 evoprompt/optimization.py        | 29 +++++++++++---
 2 files changed, 72 insertions(+), 24 deletions(-)

diff --git a/evoprompt/evolution/evolution.py b/evoprompt/evolution/evolution.py
index dbf8079..0ec9785 100644
--- a/evoprompt/evolution/evolution.py
+++ b/evoprompt/evolution/evolution.py
@@ -159,25 +159,27 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
                     )
                     self.total_evolution_usage += evolution_usage
 
-                    prompt_source = (
-                        "corrected"
-                        if not all(j.happy for j in judgements)
-                        else "generated"
-                    )
-                    evolved_prompt = self.add_prompt(
-                        p_i,
-                        parents=(pr1, pr2),
-                        meta={
-                            "gen": t,
-                            "source": prompt_source,
-                            "judgements": judgements,
-                        },
-                    )
-                    self.total_evaluation_usage += evolved_prompt.usage
-                    self.log_prompt(evolved_prompt, t, i)
-
-                    new_evolutions.append(evolved_prompt)
-                    self.save_snapshot()
+                    # If a prompt is None, it means that the prompt was skipped
+                    if p_i is not None:
+                        prompt_source = (
+                            "corrected"  # could also mean that user skipped the prompt
+                            if not all(j.happy for j in judgements)
+                            else "generated"
+                        )
+                        evolved_prompt = self.add_prompt(
+                            p_i,
+                            parents=(pr1, pr2),
+                            meta={
+                                "gen": t,
+                                "source": prompt_source,
+                                "judgements": judgements,
+                            },
+                        )
+                        self.total_evaluation_usage += evolved_prompt.usage
+                        self.log_prompt(evolved_prompt, t, i)
+
+                        new_evolutions.append(evolved_prompt)
+                        self.save_snapshot()
                 # Line 6: Update based on the evaluation scores
                 # Pt ← {Pt−1, p′i} and St ← {St−1, s′i}
                 new_population = self.update(new_evolutions, prompts_current_evolution)
@@ -249,6 +251,15 @@ class GeneticAlgorithm(EvolutionAlgorithm):
         )
 
         judgement = self.judge_and_correct_step(filled_prompt, evolved_prompt, messages)
+
+        if judgement.skip:
+            # skip this prompt, for GA this is ok since during the update step we consider all prompts keeping the population size constant
+            return (
+                None,
+                [judgement],
+                usage,
+            )
+
         evolved_prompt = judgement.corrected_response
 
         if "<prompt>" in evolved_prompt:
@@ -338,6 +349,15 @@ class DifferentialEvolution(EvolutionAlgorithm):
         )
 
         judgement = self.judge_and_correct_step(filled_prompt, evolved_prompt, messages)
+
+        if judgement.skip:
+            # skip this prompt, for DE this means using the basic prompt
+            return (
+                prompts_current_evolution[current_iteration].content,
+                [judgement],
+                usage,
+            )
+
         evolved_prompt = judgement.corrected_response
 
         matches = re.findall(
@@ -450,6 +470,15 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
                 filled_prompt, response, history=messages
             )
             judgements.append(judgement)
+
+            if judgement.skip:
+                # skip this prompt, for DE this means using the basic prompt
+                return (
+                    prompts_current_evolution[current_iteration].content,
+                    judgements,
+                    usage,
+                )
+
             # replace last message with corrected response
             messages[-1]["content"] = judgement.corrected_response
             response = judgement.corrected_response
diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index e1aad33..c1bf378 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -32,6 +32,7 @@ class Judgement(NamedTuple):
     original_response: str
     corrected_response: str
     happy: bool | None
+    skip: bool
 
 
 class PromptMeta(TypedDict):
@@ -43,8 +44,9 @@ class PromptMeta(TypedDict):
 class ResponseEditor(App):
     BINDINGS = [
         Binding(
-            key="ctrl+q", action="quit", description="Finish Editing & Save Prompt"
+            key="ctrl+s", action="quit", description="Finish Editing & Save Prompt"
         ),
+        Binding(key="ctrl+n", action="skip", description="Skip Prompt"),
     ]
 
     def __init__(
@@ -58,6 +60,7 @@ class ResponseEditor(App):
         self.response = original_response
         self.history = history
         self.judge_response = judge_response
+        self.skip = False  # used to mark the prompt as skipped
         super().__init__()
 
     def compose(self) -> ComposeResult:
@@ -79,6 +82,10 @@ class ResponseEditor(App):
         )
         yield Footer()
 
+    async def action_skip(self):
+        self.skip = True
+        await self.action_quit()
+
     @property
     def modified_response(self):
         return self.text_area.text
@@ -337,7 +344,7 @@ class PromptOptimization:
         self, instruction: str, response: str, history: ChatMessages
     ) -> Judgement:
         if self.judge_model is None:
-            return Judgement(response, response, happy=None)
+            return Judgement(response, response, happy=None, skip=False)
 
         # judge the actual response
         prompt = f"Instruction: {instruction}\nResponse: {response}"
@@ -369,10 +376,11 @@ class PromptOptimization:
         )
 
         if judge_happy:
-            return Judgement(response, response, happy=True)
+            return Judgement(response, response, happy=True, skip=False)
 
-        logger.info(f"Prompt judged as bad. Letting User change the prompt.")
+        logger.info(f"Prompt judged as bad, letting user take action.")
 
+        # let user skip or correct the response in an interactive way
         editor = ResponseEditor(
             instruction,
             response,
@@ -381,7 +389,18 @@ class PromptOptimization:
         )
         editor.run()
 
-        return Judgement(response, editor.modified_response, happy=False)
+        if editor.skip:
+            logger.info("User skipped prompt.")
+        else:
+            logger.info(
+                "User corrected prompt:\n'%s'\n -> \n'%s'",
+                response,
+                editor.modified_response,
+            )
+
+        return Judgement(
+            response, editor.modified_response, happy=False, skip=editor.skip
+        )
 
 
 def load_snapshot(path: Path):
-- 
GitLab