From 8355c2cbbe46fcf1799c6e9640581c685a4da718 Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Fri, 20 Sep 2024 12:13:22 +0200
Subject: [PATCH] Allow skipping bad evolutions based on judge model

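When --skip-bad-evolutions is set together with --judge-engine,
evolutions that the judge model rates as bad are skipped automatically
instead of prompting the user for an interactive correction. User
corrections are now logged as a line diff instead of the full
before/after texts, and --evolution-engine and --judge-engine move from
models.py into a new "Optimization arguments" group in optimization.py.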
---
 evoprompt/models.py       | 12 --------
 evoprompt/optimization.py | 59 ++++++++++++++++++++++++++++++++--------
 2 files changed, 48 insertions(+), 23 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index de88e78..8f717a1 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -399,18 +399,6 @@ class OpenAiChat(ChatModel, LLMModel):
 
 
 argument_group = argument_parser.add_argument_group("Model arguments")
-argument_group.add_argument(
-    "--evolution-engine",
-    type=str,
-    choices=LLMModel.registered_models.keys(),
-    default="llama",
-)
-argument_group.add_argument(
-    "--judge-engine",
-    type=str,
-    default=None,
-    choices=LLMModel.registered_models.keys(),
-)
 argument_group.add_argument(
     "--disable-cache",
     action="store_true",
diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index 70b3139..8153870 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -1,20 +1,22 @@
-from collections import OrderedDict
 import json
 import logging
-from pathlib import Path
 import re
+from collections import OrderedDict
+from difflib import Differ
+from pathlib import Path
 from typing import Any, Literal, NamedTuple, Optional, TypedDict
 
-from textual.app import App, ComposeResult
-from textual.binding import Binding
-from textual.containers import ScrollableContainer
-from textual.widgets import Collapsible, Footer, Label, TextArea, Static
+import wandb
 from rich.panel import Panel
 from rich.rule import Rule
 from tabulate import tabulate
+from textual.app import App, ComposeResult
+from textual.binding import Binding
+from textual.containers import ScrollableContainer
+from textual.widgets import Collapsible, Footer, Label, Static, TextArea
 from tqdm import tqdm, trange
-import wandb
 
+from evoprompt.cli import argument_parser
 from evoprompt.models import ChatMessages, LLMModel
 from evoprompt.opt_types import ModelUsage, OptTypeEncoder, Prompt
 from evoprompt.task import Task
@@ -160,6 +162,9 @@ class PromptOptimization:
         self.judge_model = judge_model
         self.run_options = run_options
 
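+        # if enabled, evolutions judged as bad are skipped instead of corrected interactively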
+        self.skip_bad_evolutions = run_options.get("skip_bad_evolutions", False)
+
     def evaluate_prompt(self, prompt: str, parents: tuple[Prompt] | None = None):
         parent_histories = (
             [parent.evaluation_history for parent in parents]
@@ -359,6 +364,7 @@ class PromptOptimization:
             for message in history
             if message["role"] in ["user", "assistant"]
         )
+        # TODO What if the history does not exist (is empty), i.e., for the first step in de-cot?
         prompt = (
             f"Context: {history_str}\nInstruction: {instruction}\nResponse: {response}"
         )
@@ -392,9 +398,13 @@ class PromptOptimization:
         if judge_happy:
             return Judgement(response, response, happy=True, skip=False)
 
-        logger.info(f"Prompt judged as bad, letting user take action.")
+        if self.skip_bad_evolutions:
+            # automatically skip responses judged as bad
+            logger.info("Prompt judged as bad, skipping.")
+            return Judgement(response, response, happy=False, skip=True)
 
         # let user skip or correct the response in an interactive way
+        logger.info(f"Prompt judged as bad, letting user take action.")
         editor = ResponseEditor(
             instruction,
             response,
@@ -406,10 +416,17 @@ class PromptOptimization:
         if editor.skip:
             logger.info("User skipped prompt.")
         else:
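+            # log only the changed lines of the correction rather than the full texts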
+            delta = Differ().compare(
+                response.splitlines(), editor.modified_response.splitlines()
+            )
             logger.info(
-                "User corrected prompt:\n'%s'\n -> \n'%s'",
-                response,
-                editor.modified_response,
+                "User corrected prompt (delta):\n%s",
+                "\n".join(
+                    line
+                    for line in delta
+                    if line.startswith(("+", "-"))
+                ),
             )
 
         return Judgement(
@@ -429,3 +446,23 @@ def load_snapshot(path: Path):
         snapshot["T"],
         snapshot["N"],
     )
+
+
+argument_group = argument_parser.add_argument_group("Optimization arguments")
+argument_group.add_argument(
+    "--evolution-engine",
+    type=str,
+    choices=LLMModel.registered_models.keys(),
+    default="llama",
+)
+argument_group.add_argument(
+    "--judge-engine",
+    type=str,
+    default=None,
+    choices=LLMModel.registered_models.keys(),
+)
+argument_group.add_argument(
+    "--skip-bad-evolutions",
+    action="store_true",
+    help="Skip bad evolutions as judged by the judge engine (needs judge engine to be set)",
+)
-- 
GitLab