From e751763b46b133c407097471b0ab636e79778fbd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Tue, 20 Aug 2024 08:42:28 +0200
Subject: [PATCH] add interactive editor to correct badly judged reponses

---
 evoprompt/api/backend.py  |   1 -
 evoprompt/evolution.py    |  51 +++++++++++++-----
 evoprompt/models.py       |  41 +++++----------
 evoprompt/optimization.py | 108 ++++++++++++++++++++------------------
 evoprompt/utils.py        |   2 +-
 5 files changed, 109 insertions(+), 94 deletions(-)

diff --git a/evoprompt/api/backend.py b/evoprompt/api/backend.py
index 608d26b..2c9e0eb 100644
--- a/evoprompt/api/backend.py
+++ b/evoprompt/api/backend.py
@@ -74,7 +74,6 @@ class MultiProcessOptimizer:
         self.model_exec = None
 
     def _setup_models(self) -> None:
-        print("setup models")
         if self._evolution_model is not None:
             raise Exception("Evolution model has already been initialized.")
 
diff --git a/evoprompt/evolution.py b/evoprompt/evolution.py
index ec67579..2b4e6c7 100644
--- a/evoprompt/evolution.py
+++ b/evoprompt/evolution.py
@@ -8,7 +8,7 @@ from tqdm import trange
 from evoprompt.cli import argument_parser
 from evoprompt.models import LLMModel
 from evoprompt.opt_types import ModelUsage, Prompt
-from evoprompt.optimization import PromptOptimization
+from evoprompt.optimization import Judgement, PromptOptimization
 from evoprompt.task import Task
 from evoprompt.template_de import get_de_prompt_template
 from evoprompt.utils import get_all_subclasses, get_rng, log_calls
@@ -85,7 +85,7 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
         *,
         prompts_current_evolution: list[Prompt],
         current_iteration: int,
-    ):
+    ) -> tuple[str, list[Judgement], ModelUsage]:
         pass
 
     @abstractmethod
@@ -116,7 +116,11 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
 
                 # Line 4: Evolution: generate a new prompt based on the selected parent prompts by leveraging LLM to perform evolutionary operators
                 # p′i ←Evo(pr1,...,prk)
-                p_i, evolution_usage = self.evolve(
+                (
+                    p_i,
+                    judgements,
+                    evolution_usage,
+                ) = self.evolve(
                     pr1,
                     pr2,
                     prompts_current_evolution=prompts_current_evolution,
@@ -124,7 +128,14 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
                 )
                 self.total_evolution_usage += evolution_usage
 
-                evolved_prompt = self.add_prompt(p_i, (pr1, pr2), {"gen": t})
+                prompt_source = (
+                    "corrected" if not all(j.happy for j in judgements) else "generated"
+                )
+                evolved_prompt = self.add_prompt(
+                    p_i,
+                    parents=(pr1, pr2),
+                    meta={"gen": t, "source": prompt_source, "judgements": judgements},
+                )
                 self.total_evaluation_usage += evolved_prompt.usage
 
                 new_evolutions.append(evolved_prompt)
@@ -171,7 +182,7 @@ class GeneticAlgorithm(EvolutionAlgorithm):
         prompt_1: str,
         prompt_2: str,
         **kwargs,
-    ) -> tuple[str, ModelUsage]:
+    ):
         # Following the evolutionary operators in GA, a new candidate prompt is generated through
         # a two-step process based on the selected two parents:
         # 1) The parent prompts undergo crossover, resulting in a new prompt that
@@ -186,6 +197,10 @@ class GeneticAlgorithm(EvolutionAlgorithm):
             system_message=SYSTEM_MESSAGE,
             prompt=filled_prompt,
         )
+
+        judgement = self.judge_and_correct_step(filled_prompt, evolved_prompt, messages)
+        evolved_prompt = judgement.corrected_response
+
         if "<prompt>" in evolved_prompt:
             evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
 
@@ -195,9 +210,8 @@ class GeneticAlgorithm(EvolutionAlgorithm):
             prompt_2,
             evolved_prompt,
         )
-        self.judge_step(filled_prompt, evolved_prompt, messages)
 
-        return evolved_prompt, usage
+        return evolved_prompt, [judgement], usage
 
     @log_calls("Performing update for GA")
     def update(
@@ -255,6 +269,10 @@ class DifferentialEvolution(EvolutionAlgorithm):
             system_message=SYSTEM_MESSAGE,
             prompt=filled_prompt,
         )
+
+        judgement = self.judge_and_correct_step(filled_prompt, evolved_prompt, messages)
+        evolved_prompt = judgement.corrected_response
+
         matches = re.findall(
             # regex that matches any characters between last pair of <prompt></prompt>, also if </prompt> is missing
             r"<prompt>(?!.*<prompt>)(?:(.*)</prompt>|(.*))",
@@ -281,9 +299,7 @@ class DifferentialEvolution(EvolutionAlgorithm):
             evolved_prompt,
         )
 
-        self.judge_step(filled_prompt, evolved_prompt, messages)
-
-        return evolved_prompt, usage
+        return evolved_prompt, [judgement], usage
 
     @log_calls("Performing update for DE")
     def update(
@@ -322,6 +338,9 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
         )
 
         messages = None
+        response: str = ""
+        judgements: list[Judgement] = []
+        usage: ModelUsage = ModelUsage()
         for idx, prompt in enumerate(DE_COT_PROMPTS):
             filled_prompt = prompt.format(
                 prompt1=prompt_1,
@@ -341,9 +360,13 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
                 messages,
                 response,
             )
-            self.judge_step(filled_prompt, response, history=messages)
-            # input(messages)
-            # input(response)
+            judgement = self.judge_and_correct_step(
+                filled_prompt, response, history=messages
+            )
+            judgements.append(judgement)
+            # replace last message with corrected response
+            messages[-1]["content"] = judgement.corrected_response
+            response = judgement.corrected_response
 
         # at this point we should get a new prompt
         if "<prompt>" in response:
@@ -358,7 +381,7 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
             response,
         )
 
-        return response, usage
+        return response, judgements, usage
 
 
 def get_all_subclasses(cls):
diff --git a/evoprompt/models.py b/evoprompt/models.py
index b2877e7..a10005d 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -63,24 +63,11 @@ class LLMModel(ABC):
         prompt: str,
         *,
         use_cache: bool = False,
-        stop: str = None,
+        stop: str | None = None,
         history: ChatMessages | None = None,
         **kwargs: Any,
-    ) -> tuple[str, list[dict[str, str]], ModelUsage]:
-        # create prompt
-        prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
-        messages = [self._get_user_message(prompt)]
-        model_input = self.build_model_input(prompt, system_message, messages, history)
-
-        reponse, usage = self._create_completion(
-            **model_input,
-            stop=stop,
-            use_cache=use_cache,
-            **kwargs,
-        )
-
-        messages.append(self._get_assistant_message(reponse))
-        return reponse, messages, usage
+    ) -> tuple[str, ChatMessages, ModelUsage]:
+        pass
 
     def _get_user_message(self, content: str):
         return {
@@ -193,7 +180,7 @@ class Llama(LLMModel):
         prompt: str,
         *,
         use_cache: bool = False,
-        stop: str = None,
+        stop: str | None = None,
         history: ChatMessages | None = None,
         **kwargs: Any,
     ) -> tuple[str, ModelUsage]:
@@ -274,22 +261,22 @@ class ChatModel:
         prompt: str,
         *,
         use_cache: bool = False,
-        stop: str = None,
+        stop: str | None = None,
         history: ChatMessages | None = None,
         **kwargs: Any,
     ) -> tuple[str, ModelUsage]:
         # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
         # TODO is it better to check for a system message in the history?
-        if history is None and system_message:
-            messages = [self._get_system_message(system_message)]
-        elif history is not None:
-            messages = history
-        else:
-            messages = []
-        messages += [self._get_user_message(prompt)]
+        messages = [self._get_user_message(prompt)]
+
+        if history is None:
+            if system_message:
+                history = [self._get_system_message(system_message)]
+            else:
+                history = []
 
         reponse, usage = self._create_completion(
-            messages=messages,
+            messages=history + messages,
             stop=stop,
             use_cache=use_cache,
             max_tokens=self.options.max_tokens,
@@ -297,7 +284,7 @@ class ChatModel:
         )
 
         messages.append(self._get_assistant_message(reponse))
-        return reponse, messages, usage
+        return reponse, history + messages, usage
 
 
 class LlamaChat(ChatModel, Llama):
diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index 648325e..4f91aba 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -1,16 +1,14 @@
 import json
 import logging
-from abc import abstractmethod
-from itertools import zip_longest
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Literal, NamedTuple, Optional, TypedDict
 
 from textual.app import App, ComposeResult
 from textual.binding import Binding
-from textual.widgets import Footer, TextArea
+from textual.widgets import Collapsible, Footer, Label, TextArea
 from tqdm import tqdm, trange
 
-from evoprompt.models import LLMModel
+from evoprompt.models import ChatMessages, LLMModel
 from evoprompt.opt_types import ModelUsage, OptTypeEncoder, Prompt
 from evoprompt.task import Task
 from evoprompt.utils import initialize_run_directory, log_calls
@@ -20,19 +18,48 @@ logger = logging.getLogger(__name__)
 
 PARAPHRASE_PROMPT = """You are given an instruction that describes a task. Write a response that paraphrases the instruction. Only output the paraphrased instruction bracketed in <prompt> and </prompt>."""
 
+PromptSource = Literal["baseprompt", "paraphrase", "evolution", "corrected"]
 
-class EditText(App):
+
+class Judgement(NamedTuple):
+    original_response: str
+    corrected_response: str
+    happy: bool
+
+
+class PromptMeta(TypedDict):
+    gen: int
+    source: PromptSource
+    judgements: list[Judgement]
+
+
+class ResponseEditor(App):
     BINDINGS = [
-        Binding(key="q", action="quit", description="Quit the app"),
+        Binding(
+            key="ctrl+q", action="quit", description="Finish Editing & Save Prompt"
+        ),
     ]
 
-    def __init__(self, text):
-        self.text = text + "asdasdasd"
+    def __init__(self, instruction: str, original_response: str, history: ChatMessages):
+        self.instruction = instruction
+        self.response = original_response
+        self.history = history
         super().__init__()
 
     def compose(self) -> ComposeResult:
-        self.text_area = TextArea.code_editor(self.text)
-        yield Footer(self.text_area)
+        self.text_area = TextArea.code_editor(self.response)
+        for idx, message in enumerate(self.history[:-1]):
+            yield Collapsible(
+                Label(message["content"]),
+                title=message["role"],
+                collapsed=idx != len(self.history) - 2,
+            )
+        yield self.text_area
+        yield Footer()
+
+    @property
+    def modified_response(self):
+        return self.text_area.text
 
 
 @log_calls("Paraphrasing prompts")
@@ -133,6 +160,7 @@ class PromptOptimization:
             prompt for _, prompt in sorted_results[: num_initial_prompts // 2]
         ]
         initial_population = top_prompts.copy()
+        prompt_sources = ["baseprompt" for _ in initial_population]
 
         # fill up the rest with paraphrases of the top prompts
         promptindex_to_paraphrase = 0
@@ -146,6 +174,7 @@ class PromptOptimization:
             )
             self.total_evolution_usage += paraphrase_usage
             initial_population += paraphrases
+            prompt_sources.append("paraphrase")
             logger.info(
                 "Paraphrased prompt '%s': %s.",
                 top_prompts[promptindex_to_paraphrase]
@@ -156,13 +185,10 @@ class PromptOptimization:
             promptindex_to_paraphrase += 1
             promptindex_to_paraphrase %= len(top_prompts)
 
-        return initial_population
+        return initial_population, prompt_sources
 
     def add_prompt(
-        self,
-        prompt: str,
-        parents: tuple[Prompt] | None = None,
-        meta: dict | None = None,
+        self, prompt: str, parents: tuple[Prompt] | None, meta: PromptMeta
     ) -> Prompt:
         score, usage, history = self.evaluate_prompt(prompt, parents)
         prompt_object = Prompt(
@@ -188,26 +214,12 @@ class PromptOptimization:
 
         return prompt_object
 
-    def add_prompts(
-        self,
-        prompts: list[str],
-        parents: list[tuple[Prompt]] = iter(()),
-        metas: list[dict] = iter(()),
-    ):
-        return [
-            self.add_prompt(prompt, _parents, meta)
-            for prompt, _parents, meta in zip_longest(prompts, parents, metas)
-        ]
-
     def get_prompt(self, prompt_id: str):
         return self.all_prompts[prompt_id]
 
     def get_prompts(self, prompt_ids: list[str]):
         return [self.get_prompt(p_id) for p_id in prompt_ids]
 
-    @abstractmethod
-    def save_snapshot(self): ...
-
     def init_run(self, num_initial_prompts: int, num_iterations: int):
         # family_tree contains the relation of prompts to its parents
         self.family_tree: dict[str, tuple[str, ...] | None] = {}
@@ -224,16 +236,15 @@ class PromptOptimization:
         self.save_snapshot()
 
         # the initial prompts
-        initial_prompts = self.get_initial_prompts(num_initial_prompts)
-        initial_prompts = self.add_prompts(
-            initial_prompts, metas=[{"gen": 0} for _ in initial_prompts]
-        )
-        # - Initial prompts P0 = {p1, p2, . . . , pN }
-        self.P.append(initial_prompts)
-
-        # accumulate usage
-        for prompt in initial_prompts:
+        initial_prompts, prompt_sources = self.get_initial_prompts(num_initial_prompts)
+        self.P.append([])
+        for prompt, prompt_source in zip(initial_prompts, prompt_sources):
+            prompt = self.add_prompt(
+                prompt, parents=None, meta={"gen": 0, "source": prompt_source}
+            )
             self.total_evaluation_usage += prompt.usage
+            # - Initial prompts P0 = {p1, p2, . . . , pN }
+            self.P[0].append(prompt)
         self.save_snapshot()
 
     def save_snapshot(self):
@@ -264,14 +275,9 @@ class PromptOptimization:
                 cls=OptTypeEncoder,
             )
 
-    def judge_step(
-        self, instruction: str, response: str, history: list[dict[str, str]]
-    ):
-        print("\n" * 20)
-        print(
-            "Instruction:", instruction, "\nResponse:", response, "\nHistory:", history
-        )
-
+    def judge_and_correct_step(
+        self, instruction: str, response: str, history: ChatMessages
+    ) -> Judgement:
         # TODO: judge the actual response
         judge_happy = False
 
@@ -279,14 +285,14 @@ class PromptOptimization:
             f"{self.judge_model.__class__.__name__} judged the response as {'good' if judge_happy else 'bad'}"
         )
         if judge_happy:
-            return True
+            return Judgement(response, response, happy=True)
 
         logger.info(f"Prompt judged as bad. Letting User change the prompt.")
 
-        editor = EditText(response)
+        editor = ResponseEditor(instruction, response, history)
         editor.run()
-        response = editor.text_area.text
-        print(response)
+
+        return Judgement(response, editor.modified_response, happy=False)
 
 
 def load_snapshot(path: Path):
diff --git a/evoprompt/utils.py b/evoprompt/utils.py
index 7cb3e1b..755b796 100644
--- a/evoprompt/utils.py
+++ b/evoprompt/utils.py
@@ -65,7 +65,7 @@ RUNS_DIR = current_directory.parent / "runs"
 RUNS_DIR.mkdir(exist_ok=True)
 
 
-def initialize_run_directory(model: Callable):
+def initialize_run_directory(model):
     global file_handler
 
     # get package logger
-- 
GitLab