From f021dfcb1444286e3c57626d1082d8f5a427fc4b Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Fri, 6 Sep 2024 18:38:54 +0200
Subject: [PATCH] Allow user to skip evolution if response was judged as bad

---
 evoprompt/evolution/evolution.py | 67 +++++++++++++++++++++++---------
 evoprompt/optimization.py        | 29 +++++++++++---
 2 files changed, 72 insertions(+), 24 deletions(-)

diff --git a/evoprompt/evolution/evolution.py b/evoprompt/evolution/evolution.py
index dbf8079..0ec9785 100644
--- a/evoprompt/evolution/evolution.py
+++ b/evoprompt/evolution/evolution.py
@@ -159,25 +159,27 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
                 )
                 self.total_evolution_usage += evolution_usage
 
-                prompt_source = (
-                    "corrected"
-                    if not all(j.happy for j in judgements)
-                    else "generated"
-                )
-                evolved_prompt = self.add_prompt(
-                    p_i,
-                    parents=(pr1, pr2),
-                    meta={
-                        "gen": t,
-                        "source": prompt_source,
-                        "judgements": judgements,
-                    },
-                )
-                self.total_evaluation_usage += evolved_prompt.usage
-                self.log_prompt(evolved_prompt, t, i)
-
-                new_evolutions.append(evolved_prompt)
-                self.save_snapshot()
+                # If a prompt is None, it means that the prompt was skipped
+                if p_i is not None:
+                    prompt_source = (
+                        "corrected"  # could also mean that user skipped the prompt
+                        if not all(j.happy for j in judgements)
+                        else "generated"
+                    )
+                    evolved_prompt = self.add_prompt(
+                        p_i,
+                        parents=(pr1, pr2),
+                        meta={
+                            "gen": t,
+                            "source": prompt_source,
+                            "judgements": judgements,
+                        },
+                    )
+                    self.total_evaluation_usage += evolved_prompt.usage
+                    self.log_prompt(evolved_prompt, t, i)
+
+                    new_evolutions.append(evolved_prompt)
+                    self.save_snapshot()
             # Line 6: Update based on the evaluation scores
             # Pt ← {Pt−1, p′i} and St ← {St−1, s′i}
             new_population = self.update(new_evolutions, prompts_current_evolution)
@@ -249,6 +251,15 @@ class GeneticAlgorithm(EvolutionAlgorithm):
         )
 
         judgement = self.judge_and_correct_step(filled_prompt, evolved_prompt, messages)
+
+        if judgement.skip:
+            # skip this prompt, for GA this is ok since during the update step we consider all prompts keeping the population size constant
+            return (
+                None,
+                [judgement],
+                usage,
+            )
+
         evolved_prompt = judgement.corrected_response
 
         if "<prompt>" in evolved_prompt:
@@ -338,6 +349,15 @@ class DifferentialEvolution(EvolutionAlgorithm):
         )
 
         judgement = self.judge_and_correct_step(filled_prompt, evolved_prompt, messages)
+
+        if judgement.skip:
+            # skip this prompt, for DE this means using the basic prompt
+            return (
+                prompts_current_evolution[current_iteration].content,
+                [judgement],
+                usage,
+            )
+
         evolved_prompt = judgement.corrected_response
 
         matches = re.findall(
@@ -450,6 +470,15 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
                 filled_prompt, response, history=messages
             )
             judgements.append(judgement)
+
+            if judgement.skip:
+                # skip this prompt, for DE this means using the basic prompt
+                return (
+                    prompts_current_evolution[current_iteration].content,
+                    judgements,
+                    usage,
+                )
+
             # replace last message with corrected response
             messages[-1]["content"] = judgement.corrected_response
             response = judgement.corrected_response
diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index e1aad33..c1bf378 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -32,6 +32,7 @@ class Judgement(NamedTuple):
     original_response: str
     corrected_response: str
    happy: bool | None
+    skip: bool
 
 
 class PromptMeta(TypedDict):
@@ -43,8 +44,9 @@ class ResponseEditor(App):
 
     BINDINGS = [
         Binding(
-            key="ctrl+q", action="quit", description="Finish Editing & Save Prompt"
+            key="ctrl+s", action="quit", description="Finish Editing & Save Prompt"
         ),
+        Binding(key="ctrl+n", action="skip", description="Skip Prompt"),
     ]
 
     def __init__(
@@ -58,6 +60,7 @@ class ResponseEditor(App):
         self.response = original_response
         self.history = history
         self.judge_response = judge_response
+        self.skip = False  # used to mark the prompt as skipped
         super().__init__()
 
     def compose(self) -> ComposeResult:
@@ -79,6 +82,10 @@ class ResponseEditor(App):
         )
         yield Footer()
 
+    async def action_skip(self):
+        self.skip = True
+        await self.action_quit()
+
     @property
     def modified_response(self):
         return self.text_area.text
@@ -337,7 +344,7 @@ class PromptOptimization:
         self, instruction: str, response: str, history: ChatMessages
     ) -> Judgement:
         if self.judge_model is None:
-            return Judgement(response, response, happy=None)
+            return Judgement(response, response, happy=None, skip=False)
 
         # judge the actual response
         prompt = f"Instruction: {instruction}\nResponse: {response}"
@@ -369,10 +376,11 @@ class PromptOptimization:
         )
 
         if judge_happy:
-            return Judgement(response, response, happy=True)
+            return Judgement(response, response, happy=True, skip=False)
 
-        logger.info(f"Prompt judged as bad. Letting User change the prompt.")
+        logger.info(f"Prompt judged as bad, letting user take action.")
 
+        # let user skip or correct the response in an interactive way
         editor = ResponseEditor(
             instruction,
             response,
@@ -381,7 +389,18 @@ class PromptOptimization:
         )
         editor.run()
 
-        return Judgement(response, editor.modified_response, happy=False)
+        if editor.skip:
+            logger.info("User skipped prompt.")
+        else:
+            logger.info(
+                "User corrected prompt:\n'%s'\n -> \n'%s'",
+                response,
+                editor.modified_response,
+            )
+
+        return Judgement(
+            response, editor.modified_response, happy=False, skip=editor.skip
+        )
 
 
 def load_snapshot(path: Path):
-- 
GitLab