From e751763b46b133c407097471b0ab636e79778fbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de> Date: Tue, 20 Aug 2024 08:42:28 +0200 Subject: [PATCH] add interactive editor to correct badly judged reponses --- evoprompt/api/backend.py | 1 - evoprompt/evolution.py | 51 +++++++++++++----- evoprompt/models.py | 41 +++++---------- evoprompt/optimization.py | 108 ++++++++++++++++++++------------------ evoprompt/utils.py | 2 +- 5 files changed, 109 insertions(+), 94 deletions(-) diff --git a/evoprompt/api/backend.py b/evoprompt/api/backend.py index 608d26b..2c9e0eb 100644 --- a/evoprompt/api/backend.py +++ b/evoprompt/api/backend.py @@ -74,7 +74,6 @@ class MultiProcessOptimizer: self.model_exec = None def _setup_models(self) -> None: - print("setup models") if self._evolution_model is not None: raise Exception("Evolution model has already been initialized.") diff --git a/evoprompt/evolution.py b/evoprompt/evolution.py index ec67579..2b4e6c7 100644 --- a/evoprompt/evolution.py +++ b/evoprompt/evolution.py @@ -8,7 +8,7 @@ from tqdm import trange from evoprompt.cli import argument_parser from evoprompt.models import LLMModel from evoprompt.opt_types import ModelUsage, Prompt -from evoprompt.optimization import PromptOptimization +from evoprompt.optimization import Judgement, PromptOptimization from evoprompt.task import Task from evoprompt.template_de import get_de_prompt_template from evoprompt.utils import get_all_subclasses, get_rng, log_calls @@ -85,7 +85,7 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta): *, prompts_current_evolution: list[Prompt], current_iteration: int, - ): + ) -> tuple[str, list[Judgement], ModelUsage]: pass @abstractmethod @@ -116,7 +116,11 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta): # Line 4: Evolution: generate a new prompt based on the selected parent prompts by leveraging LLM to perform evolutionary operators # p′i â†Evo(pr1,...,prk) - p_i, evolution_usage = self.evolve( + ( + p_i, + judgements, + evolution_usage, + ) = self.evolve( pr1, pr2, prompts_current_evolution=prompts_current_evolution, @@ -124,7 +128,14 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta): ) self.total_evolution_usage += evolution_usage - evolved_prompt = self.add_prompt(p_i, (pr1, pr2), {"gen": t}) + prompt_source = ( + "corrected" if not all(j.happy for j in judgements) else "generated" + ) + evolved_prompt = self.add_prompt( + p_i, + parents=(pr1, pr2), + meta={"gen": t, "source": prompt_source, "judgements": judgements}, + ) self.total_evaluation_usage += evolved_prompt.usage new_evolutions.append(evolved_prompt) @@ -171,7 +182,7 @@ class GeneticAlgorithm(EvolutionAlgorithm): prompt_1: str, prompt_2: str, **kwargs, - ) -> tuple[str, ModelUsage]: + ): # Following the evolutionary operators in GA, a new candidate prompt is generated through # a two-step process based on the selected two parents: # 1) The parent prompts undergo crossover, resulting in a new prompt that @@ -186,6 +197,10 @@ class GeneticAlgorithm(EvolutionAlgorithm): system_message=SYSTEM_MESSAGE, prompt=filled_prompt, ) + + judgement = self.judge_and_correct_step(filled_prompt, evolved_prompt, messages) + evolved_prompt = judgement.corrected_response + if "<prompt>" in evolved_prompt: evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0] @@ -195,9 +210,8 @@ class GeneticAlgorithm(EvolutionAlgorithm): prompt_2, evolved_prompt, ) - self.judge_step(filled_prompt, evolved_prompt, messages) - return 
evolved_prompt, usage + return evolved_prompt, [judgement], usage @log_calls("Performing update for GA") def update( @@ -255,6 +269,10 @@ class DifferentialEvolution(EvolutionAlgorithm): system_message=SYSTEM_MESSAGE, prompt=filled_prompt, ) + + judgement = self.judge_and_correct_step(filled_prompt, evolved_prompt, messages) + evolved_prompt = judgement.corrected_response + matches = re.findall( # regex that matches any characters between last pair of <prompt></prompt>, also if </prompt> is missing r"<prompt>(?!.*<prompt>)(?:(.*)</prompt>|(.*))", @@ -281,9 +299,7 @@ class DifferentialEvolution(EvolutionAlgorithm): evolved_prompt, ) - self.judge_step(filled_prompt, evolved_prompt, messages) - - return evolved_prompt, usage + return evolved_prompt, [judgement], usage @log_calls("Performing update for DE") def update( @@ -322,6 +338,9 @@ class DifferentialEvolutionWithCot(DifferentialEvolution): ) messages = None + response: str = "" + judgements: list[Judgement] = [] + usage: ModelUsage = ModelUsage() for idx, prompt in enumerate(DE_COT_PROMPTS): filled_prompt = prompt.format( prompt1=prompt_1, @@ -341,9 +360,13 @@ class DifferentialEvolutionWithCot(DifferentialEvolution): messages, response, ) - self.judge_step(filled_prompt, response, history=messages) - # input(messages) - # input(response) + judgement = self.judge_and_correct_step( + filled_prompt, response, history=messages + ) + judgements.append(judgement) + # replace last message with corrected response + messages[-1]["content"] = judgement.corrected_response + response = judgement.corrected_response # at this point we should get a new prompt if "<prompt>" in response: @@ -358,7 +381,7 @@ class DifferentialEvolutionWithCot(DifferentialEvolution): response, ) - return response, usage + return response, judgements, usage def get_all_subclasses(cls): diff --git a/evoprompt/models.py b/evoprompt/models.py index b2877e7..a10005d 100644 --- a/evoprompt/models.py +++ b/evoprompt/models.py @@ -63,24 +63,11 @@ class LLMModel(ABC): prompt: str, *, use_cache: bool = False, - stop: str = None, + stop: str | None = None, history: ChatMessages | None = None, **kwargs: Any, - ) -> tuple[str, list[dict[str, str]], ModelUsage]: - # create prompt - prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix - messages = [self._get_user_message(prompt)] - model_input = self.build_model_input(prompt, system_message, messages, history) - - reponse, usage = self._create_completion( - **model_input, - stop=stop, - use_cache=use_cache, - **kwargs, - ) - - messages.append(self._get_assistant_message(reponse)) - return reponse, messages, usage + ) -> tuple[str, ChatMessages, ModelUsage]: + pass def _get_user_message(self, content: str): return { @@ -193,7 +180,7 @@ class Llama(LLMModel): prompt: str, *, use_cache: bool = False, - stop: str = None, + stop: str | None = None, history: ChatMessages | None = None, **kwargs: Any, ) -> tuple[str, ModelUsage]: @@ -274,22 +261,22 @@ class ChatModel: prompt: str, *, use_cache: bool = False, - stop: str = None, + stop: str | None = None, history: ChatMessages | None = None, **kwargs: Any, ) -> tuple[str, ModelUsage]: # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case # TODO is it better to check for a system message in the history? 
- if history is None and system_message: - messages = [self._get_system_message(system_message)] - elif history is not None: - messages = history - else: - messages = [] - messages += [self._get_user_message(prompt)] + messages = [self._get_user_message(prompt)] + + if history is None: + if system_message: + history = [self._get_system_message(system_message)] + else: + history = [] reponse, usage = self._create_completion( - messages=messages, + messages=history + messages, stop=stop, use_cache=use_cache, max_tokens=self.options.max_tokens, @@ -297,7 +284,7 @@ class ChatModel: ) messages.append(self._get_assistant_message(reponse)) - return reponse, messages, usage + return reponse, history + messages, usage class LlamaChat(ChatModel, Llama): diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py index 648325e..4f91aba 100644 --- a/evoprompt/optimization.py +++ b/evoprompt/optimization.py @@ -1,16 +1,14 @@ import json import logging -from abc import abstractmethod -from itertools import zip_longest from pathlib import Path -from typing import Any, Optional +from typing import Any, Literal, NamedTuple, Optional, TypedDict from textual.app import App, ComposeResult from textual.binding import Binding -from textual.widgets import Footer, TextArea +from textual.widgets import Collapsible, Footer, Label, TextArea from tqdm import tqdm, trange -from evoprompt.models import LLMModel +from evoprompt.models import ChatMessages, LLMModel from evoprompt.opt_types import ModelUsage, OptTypeEncoder, Prompt from evoprompt.task import Task from evoprompt.utils import initialize_run_directory, log_calls @@ -20,19 +18,48 @@ logger = logging.getLogger(__name__) PARAPHRASE_PROMPT = """You are given an instruction that describes a task. Write a response that paraphrases the instruction. 
Only output the paraphrased instruction bracketed in <prompt> and </prompt>.""" +PromptSource = Literal["baseprompt", "paraphrase", "evolution", "corrected"] -class EditText(App): + +class Judgement(NamedTuple): + original_response: str + corrected_response: str + happy: bool + + +class PromptMeta(TypedDict): + gen: int + source: PromptSource + judgements: list[Judgement] + + +class ResponseEditor(App): BINDINGS = [ - Binding(key="q", action="quit", description="Quit the app"), + Binding( + key="ctrl+q", action="quit", description="Finish Editing & Save Prompt" + ), ] - def __init__(self, text): - self.text = text + "asdasdasd" + def __init__(self, instruction: str, original_response: str, history: ChatMessages): + self.instruction = instruction + self.response = original_response + self.history = history super().__init__() def compose(self) -> ComposeResult: - self.text_area = TextArea.code_editor(self.text) - yield Footer(self.text_area) + self.text_area = TextArea.code_editor(self.response) + for idx, message in enumerate(self.history[:-1]): + yield Collapsible( + Label(message["content"]), + title=message["role"], + collapsed=idx != len(self.history) - 2, + ) + yield self.text_area + yield Footer() + + @property + def modified_response(self): + return self.text_area.text @log_calls("Paraphrasing prompts") @@ -133,6 +160,7 @@ class PromptOptimization: prompt for _, prompt in sorted_results[: num_initial_prompts // 2] ] initial_population = top_prompts.copy() + prompt_sources = ["baseprompt" for _ in initial_population] # fill up the rest with paraphrases of the top prompts promptindex_to_paraphrase = 0 @@ -146,6 +174,7 @@ class PromptOptimization: ) self.total_evolution_usage += paraphrase_usage initial_population += paraphrases + prompt_sources.append("paraphrase") logger.info( "Paraphrased prompt '%s': %s.", top_prompts[promptindex_to_paraphrase] @@ -156,13 +185,10 @@ class PromptOptimization: promptindex_to_paraphrase += 1 promptindex_to_paraphrase %= len(top_prompts) - return initial_population + return initial_population, prompt_sources def add_prompt( - self, - prompt: str, - parents: tuple[Prompt] | None = None, - meta: dict | None = None, + self, prompt: str, parents: tuple[Prompt] | None, meta: PromptMeta ) -> Prompt: score, usage, history = self.evaluate_prompt(prompt, parents) prompt_object = Prompt( @@ -188,26 +214,12 @@ class PromptOptimization: return prompt_object - def add_prompts( - self, - prompts: list[str], - parents: list[tuple[Prompt]] = iter(()), - metas: list[dict] = iter(()), - ): - return [ - self.add_prompt(prompt, _parents, meta) - for prompt, _parents, meta in zip_longest(prompts, parents, metas) - ] - def get_prompt(self, prompt_id: str): return self.all_prompts[prompt_id] def get_prompts(self, prompt_ids: list[str]): return [self.get_prompt(p_id) for p_id in prompt_ids] - @abstractmethod - def save_snapshot(self): ... - def init_run(self, num_initial_prompts: int, num_iterations: int): # family_tree contains the relation of prompts to its parents self.family_tree: dict[str, tuple[str, ...] | None] = {} @@ -224,16 +236,15 @@ class PromptOptimization: self.save_snapshot() # the initial prompts - initial_prompts = self.get_initial_prompts(num_initial_prompts) - initial_prompts = self.add_prompts( - initial_prompts, metas=[{"gen": 0} for _ in initial_prompts] - ) - # - Initial prompts P0 = {p1, p2, . . . 
, pN } - self.P.append(initial_prompts) - - # accumulate usage - for prompt in initial_prompts: + initial_prompts, prompt_sources = self.get_initial_prompts(num_initial_prompts) + self.P.append([]) + for prompt, prompt_source in zip(initial_prompts, prompt_sources): + prompt = self.add_prompt( + prompt, parents=None, meta={"gen": 0, "source": prompt_source} + ) self.total_evaluation_usage += prompt.usage + # - Initial prompts P0 = {p1, p2, . . . , pN } + self.P[0].append(prompt) self.save_snapshot() def save_snapshot(self): @@ -264,14 +275,9 @@ class PromptOptimization: cls=OptTypeEncoder, ) - def judge_step( - self, instruction: str, response: str, history: list[dict[str, str]] - ): - print("\n" * 20) - print( - "Instruction:", instruction, "\nResponse:", response, "\nHistory:", history - ) - + def judge_and_correct_step( + self, instruction: str, response: str, history: ChatMessages + ) -> Judgement: # TODO: judge the actual response judge_happy = False @@ -279,14 +285,14 @@ class PromptOptimization: f"{self.judge_model.__class__.__name__} judged the response as {'good' if judge_happy else 'bad'}" ) if judge_happy: - return True + return Judgement(response, response, happy=True) logger.info(f"Prompt judged as bad. Letting User change the prompt.") - editor = EditText(response) + editor = ResponseEditor(instruction, response, history) editor.run() - response = editor.text_area.text - print(response) + + return Judgement(response, editor.modified_response, happy=False) def load_snapshot(path: Path): diff --git a/evoprompt/utils.py b/evoprompt/utils.py index 7cb3e1b..755b796 100644 --- a/evoprompt/utils.py +++ b/evoprompt/utils.py @@ -65,7 +65,7 @@ RUNS_DIR = current_directory.parent / "runs" RUNS_DIR.mkdir(exist_ok=True) -def initialize_run_directory(model: Callable): +def initialize_run_directory(model): global file_handler # get package logger -- GitLab
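
Editorial note (appended after the patch trailer, not part of the commit): the sketch below is a minimal, standalone approximation of the judge-and-correct flow this change threads through the GA/DE evolution operators in evolution.py and optimization.py. In the real code the judgement comes from self.judge_model and the correction from the Textual ResponseEditor; both are replaced here by stubs (an `edit` callable and the hard-coded judge) so the example runs on its own.

    from typing import Callable, NamedTuple


    class Judgement(NamedTuple):
        # Mirrors the NamedTuple the patch adds in optimization.py.
        original_response: str
        corrected_response: str
        happy: bool


    def judge_and_correct_step(
        instruction: str, response: str, edit: Callable[[str], str]
    ) -> Judgement:
        # The patch still hard-codes judge_happy = False (see its TODO), so every
        # response is currently routed to the interactive editor for correction.
        judge_happy = False
        if judge_happy:
            return Judgement(response, response, happy=True)
        return Judgement(response, edit(response), happy=False)


    def evolve_ga(prompt_1: str, prompt_2: str, edit: Callable[[str], str]):
        # Stand-in for GeneticAlgorithm.evolve: build the meta-prompt, pretend the
        # evolution model answered, then judge/correct before extracting <prompt>.
        filled_prompt = f"Cross over and mutate: 1) {prompt_1} 2) {prompt_2}"
        evolved = f"<prompt>{prompt_1} {prompt_2}</prompt>"  # placeholder LLM output
        judgement = judge_and_correct_step(filled_prompt, evolved, edit)
        evolved = judgement.corrected_response  # continue with the corrected text
        if "<prompt>" in evolved:
            evolved = evolved.split("<prompt>")[1].split("</prompt>")[0]
        return evolved, [judgement]


    if __name__ == "__main__":
        # A no-op "editor" stands in for ResponseEditor.run()/modified_response.
        prompt, judgements = evolve_ga("Classify the text.", "Label the sentiment.",
                                       lambda response: response)
        source = "corrected" if not all(j.happy for j in judgements) else "generated"
        print(source, prompt)  # corrected Classify the text. Label the sentiment.

As in the patch, the evolved prompt is recorded with source "corrected" whenever any judgement was unhappy, so manually corrected prompts stay distinguishable from purely generated ones in the run snapshot.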
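
Follow-up usage note (also editorial): when the judge is unhappy, judge_and_correct_step in this patch blocks on the Textual app and reads the user's edit back, roughly:

    editor = ResponseEditor(instruction, response, history)
    editor.run()                           # prior messages in Collapsible widgets, response in a TextArea
    corrected = editor.modified_response   # Ctrl+Q ("Finish Editing & Save Prompt") closes the app

Only the most recent history message is expanded by default (collapsed=idx != len(self.history) - 2), so the user corrects the response with the full conversation available but the immediately relevant context in view.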