From 13e3d550c030344ae972722f4c6196f3104c6949 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Thu, 2 Jan 2025 10:31:22 +0100
Subject: [PATCH] add framework to retry evolution with human feedback if
 performance did not improve

---
 evoprompt/evolution/evolution.py | 144 ++++++++++++++++++-------------
 1 file changed, 83 insertions(+), 61 deletions(-)

diff --git a/evoprompt/evolution/evolution.py b/evoprompt/evolution/evolution.py
index 76a8275..627b1ea 100644
--- a/evoprompt/evolution/evolution.py
+++ b/evoprompt/evolution/evolution.py
@@ -556,79 +556,101 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
         with weave.attributes({"run_name": self.run_name}):
             # Algorithm 1 Discrete prompt optimization: EVOPROMPT
-            # Line 2:
             for t in self.iterations_pbar:
                 # Line 3: Selection: select a certain number of prompts from current population as parent prompts
                 # pr1,...,prk ∼ Pt−1
-                prompts_current_evolution = self.P[t - 1]
-
-                new_evolutions = []
-                num_failed_automatic_evolutions = 0
-                for i in trange(
-                    self.population_size,
-                    desc="updates",
-                    leave=False,
-                    disable=None,
+                # if the iteration does not improve over the previous generation, we rerun the iteration with human feedback
+                # TODO: what to do if the iteration does not improve over the previous generation even with human feedback?
+                # this design allows us to introduce a `max_human_feedback_iterations` parameter if needed
+                iteration_try_for_generation = 0
+                iteration_has_improved = False
+                while (
+                    not iteration_has_improved
+                    and iteration_try_for_generation < 2
                 ):
-                    # for both GA and DE we start with two parent prompts
-                    pr1, pr2 = self.select(self.P[t - 1])
+                    prompts_current_evolution = self.P[t - 1]
+
+                    new_evolutions = []
+                    num_failed_automatic_evolutions = 0
+                    for i in trange(
+                        self.population_size,
+                        desc="updates",
+                        leave=False,
+                        disable=None,
+                    ):
+                        # for both GA and DE we start with two parent prompts
+                        pr1, pr2 = self.select(self.P[t - 1])
+
+                        # Line 4: Evolution: generate a new prompt based on the selected parent prompts by leveraging LLM to perform evolutionary operation
+                        # p′i ← Evo(pr1,...,prk)
+                        (
+                            evolved_prompt,
+                            judgements,
+                            evolution_usage,
+                        ) = self.evolve(
+                            pr1,
+                            pr2,
+                            prompts_current_evolution=prompts_current_evolution,
+                            current_iteration=i,
+                            current_generation=t,
+                        )
+                        self.total_evolution_usage += evolution_usage
-                    # Line 4: Evolution: generate a new prompt based on the selected parent prompts by leveraging LLM to perform evolutionary operation
-                    # p′i ← Evo(pr1,...,prk)
-                    (
-                        evolved_prompt,
-                        judgements,
-                        evolution_usage,
-                    ) = self.evolve(
-                        pr1,
-                        pr2,
+                        automatic_prompt_evolution_failed = bool(
+                            {Judgement.BAD, Judgement.FAIL} & set(judgements)
+                        )
+                        if automatic_prompt_evolution_failed:
+                            num_failed_automatic_evolutions += 1
+                        # If a prompt is None, it means that the prompt was skipped
+                        if evolved_prompt is not None:
+                            prompt_source = (
+                                "corrected"  # could also mean that the user skipped the prompt
+                                if automatic_prompt_evolution_failed
+                                else "evolution"
+                            )
+
+                            evolved_prompt = self.add_prompt(
+                                evolved_prompt,
+                                parents=(pr1, pr2),
+                                meta={
+                                    "gen": t,
+                                    "source": prompt_source,
+                                    "judgements": judgements,
+                                },
+                            )
+                            self.total_evaluation_usage += evolved_prompt.usage
+                            self.log_prompt(
+                                evolved_prompt,
+                                generation=t,
+                                update_step=i,
+                                total_update_step=t * self.population_size + i,
+                            )
+
+                            new_evolutions.append(evolved_prompt)
+                        else:
+                            new_evolutions.append(None)
+                        self.save_snapshot()
+                    # Line 6: Update based on the evaluation scores
+                    # Pt ← {Pt−1, p′i} and St ← {St−1, s′i}
+                    new_population = self.update(
+                        new_evolutions=new_evolutions,
                         prompts_current_evolution=prompts_current_evolution,
-                        current_iteration=i,
-                        current_generation=t,
                     )
-                    self.total_evolution_usage += evolution_usage
-                    automatic_prompt_evolution_failed = bool(
-                        {Judgement.BAD, Judgement.FAIL} & set(judgements)
+                    best_score = lambda population: max(
+                        prompt.score for prompt in population
                     )
-                    if automatic_prompt_evolution_failed:
-                        num_failed_automatic_evolutions += 1
-                    # If a prompt is None, it means that the prompt was skipped
-                    if evolved_prompt is not None:
-                        prompt_source = (
-                            "corrected"  # could also mean that the user skipped the prompt
-                            if automatic_prompt_evolution_failed
-                            else "evolution"
-                        )
-                        evolved_prompt = self.add_prompt(
-                            evolved_prompt,
-                            parents=(pr1, pr2),
-                            meta={
-                                "gen": t,
-                                "source": prompt_source,
-                                "judgements": judgements,
-                            },
+                    iteration_has_improved = best_score(
+                        new_population
+                    ) > best_score(prompts_current_evolution)
+                    iteration_try_for_generation += 1
+                    if not iteration_has_improved:
+                        logger.info(
+                            "Iteration %d did not improve over previous generation. Rerunning with human feedback",
+                            t,
                         )
-                        self.total_evaluation_usage += evolved_prompt.usage
-                        self.log_prompt(
-                            evolved_prompt,
-                            generation=t,
-                            update_step=i,
-                            total_update_step=t * self.population_size + i,
-                        )
-
-                        new_evolutions.append(evolved_prompt)
-                    else:
-                        new_evolutions.append(None)
-                    self.save_snapshot()
-                # Line 6: Update based on the evaluation scores
-                # Pt ← {Pt−1, p′i} and St ← {St−1, s′i}
-                new_population = self.update(
-                    new_evolutions=new_evolutions,
-                    prompts_current_evolution=prompts_current_evolution,
-                )

                 # log metrics to wandb
                 self.log_iteration(
--
GitLab
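
For illustration, the retry control flow this patch introduces can be reduced to a few lines. What follows is a hypothetical, self-contained Python sketch, not evoprompt code: `Prompt`, `run_update_step`, and `max_tries` are invented stand-ins for the patched method's population update, and human feedback is modeled as a plain flag handed to the update step.

# Hypothetical sketch of the retry-with-human-feedback loop; `Prompt` and
# `run_update_step` are invented stand-ins, not evoprompt APIs.
import random
from dataclasses import dataclass


@dataclass
class Prompt:
    text: str
    score: float


def run_update_step(previous: list[Prompt], human_feedback: bool) -> list[Prompt]:
    # Stand-in for one EvoPrompt iteration (selection, Evo(...), update).
    # Human feedback is modeled here as a small score bonus on retries.
    bonus = 0.1 if human_feedback else 0.0
    return [
        Prompt(p.text, min(1.0, p.score + random.uniform(-0.05, 0.05) + bonus))
        for p in previous
    ]


def best_score(population: list[Prompt]) -> float:
    # A generation is judged by its single best prompt, as in the patch.
    return max(prompt.score for prompt in population)


def iterate_with_retry(previous: list[Prompt], max_tries: int = 2) -> list[Prompt]:
    # Rerun the update step until it beats the previous generation or the
    # retry budget runs out: the first try is fully automatic, retries
    # would ask for human feedback.
    tries = 0
    improved = False
    population = previous
    while not improved and tries < max_tries:
        population = run_update_step(previous, human_feedback=tries > 0)
        improved = best_score(population) > best_score(previous)
        tries += 1  # consume the retry budget so the loop always terminates
    return population


if __name__ == "__main__":
    seed = [Prompt("Classify the sentiment.", 0.60), Prompt("Label the text.", 0.55)]
    print([round(p.score, 3) for p in iterate_with_retry(seed)])

The explicit `tries` increment is what bounds the loop at `max_tries` attempts; raising `max_tries` would play the role of the `max_human_feedback_iterations` parameter mentioned in the patch's TODO comment.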