Commit 13e3d550 authored by Grießhaber Daniel

add framework to retry evolution with human feedback if performance did not improve

parent bc7445bc
1 merge request: !12 User interaction only after non improvement
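In outline, the commit wraps each generation in a bounded retry loop: evolve the population, score it, and rerun the generation once with human feedback if the best score did not beat the previous generation. A minimal runnable sketch of that control flow, under stated assumptions: `run_generation`, `best_score`, and `MAX_TRIES_PER_GENERATION` are illustrative stand-ins, not names from this repository.

    import random

    MAX_TRIES_PER_GENERATION = 2  # hypothetical: one automatic try, then one retry with human feedback


    def best_score(population: list[float]) -> float:
        # stand-in scoring: here each "prompt" is represented by its score alone
        return max(population)


    def run_generation(previous: list[float], human_feedback: bool) -> list[float]:
        # stand-in for one evolution step; human feedback biases mutations upward
        bias = 0.2 if human_feedback else 0.0
        return [p + random.uniform(-0.1, 0.1 + bias) for p in previous]


    def evolve_generation(previous: list[float]) -> list[float]:
        tries = 0
        improved = False
        population = previous
        while not improved and tries < MAX_TRIES_PER_GENERATION:
            tries += 1
            # the first try is fully automatic; later tries ask the user for feedback
            population = run_generation(previous, human_feedback=tries > 1)
            improved = best_score(population) > best_score(previous)
        return population


    print(evolve_generation([0.4, 0.5, 0.6]))

As in the commit's own TODO, the open question remains what to do when even the human-feedback retry does not improve; the sketch simply keeps the last population.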
@@ -556,79 +556,100 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):

Resulting code after the change (the new side of the hunk):

        with weave.attributes({"run_name": self.run_name}):
            # Algorithm 1 Discrete prompt optimization: EVOPROMPT
            # Line 2:
            for t in self.iterations_pbar:
                # Line 3: Selection: select a certain number of prompts from current population as parent prompts
                # pr1,...,prk ∼ Pt−1

                # if the iteration does not improve over the previous generation, we rerun the iteration with human feedback
                # TODO: what to do if the iteration does not improve over the previous generation even with human feedback?
                # this design allows us to introduce a `max_human_feedback_iterations` parameter if needed
                iteration_try_for_generation = 0
                iteration_has_improved = False
                while (
                    not iteration_has_improved
                    and iteration_try_for_generation < 2
                ):
                    # bound the retries: one automatic try, then at most one with human feedback
                    iteration_try_for_generation += 1
                    prompts_current_evolution = self.P[t - 1]
                    new_evolutions = []
                    num_failed_automatic_evolutions = 0
                    for i in trange(
                        self.population_size,
                        desc="updates",
                        leave=False,
                        disable=None,
                    ):
                        # for both GA and DE we start with two parent prompts
                        pr1, pr2 = self.select(self.P[t - 1])
                        # Line 4: Evolution: generate a new prompt based on the selected parent prompts by leveraging LLM to perform evolutionary operation
                        # p′i ← Evo(pr1,...,prk)
                        (
                            evolved_prompt,
                            judgements,
                            evolution_usage,
                        ) = self.evolve(
                            pr1,
                            pr2,
                            prompts_current_evolution=prompts_current_evolution,
                            current_iteration=i,
                            current_generation=t,
                        )
                        self.total_evolution_usage += evolution_usage
                        automatic_prompt_evolution_failed = bool(
                            {Judgement.BAD, Judgement.FAIL} & set(judgements)
                        )
                        if automatic_prompt_evolution_failed:
                            num_failed_automatic_evolutions += 1
                        # if a prompt is None, it means that the prompt was skipped
                        if evolved_prompt is not None:
                            prompt_source = (
                                "corrected"  # could also mean that the user skipped the prompt
                                if automatic_prompt_evolution_failed
                                else "evolution"
                            )
                            evolved_prompt = self.add_prompt(
                                evolved_prompt,
                                parents=(pr1, pr2),
                                meta={
                                    "gen": t,
                                    "source": prompt_source,
                                    "judgements": judgements,
                                },
                            )
                            self.total_evaluation_usage += evolved_prompt.usage
                            self.log_prompt(
                                evolved_prompt,
                                generation=t,
                                update_step=i,
                                total_update_step=t * self.population_size + i,
                            )
                            new_evolutions.append(evolved_prompt)
                        else:
                            new_evolutions.append(None)
                        self.save_snapshot()
                    # Line 6: Update based on the evaluation scores
                    # Pt ← {Pt−1, p′i} and St ← {St−1, s′i}
                    new_population = self.update(
                        new_evolutions=new_evolutions,
                        prompts_current_evolution=prompts_current_evolution,
                    )
                    # the iteration has improved iff the best score of the new population
                    # beats the best score of the previous generation
                    best_score = lambda population: max(
                        prompt.score for prompt in population
                    )
                    iteration_has_improved = best_score(
                        new_population
                    ) > best_score(prompts_current_evolution)
                    if not iteration_has_improved:
                        logger.info(
                            "Iteration %d did not improve over previous generation. Rerunning with human feedback",
                            t,
                        )
                # log metrics to wandb
                self.log_iteration(
                ...
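The TODO in the hunk anticipates replacing the hard-coded bound of 2 with a `max_human_feedback_iterations` parameter. One way the retry condition could be parameterized, as a sketch only (this parameter and helper do not exist in this commit):

    # hypothetical refactor of the retry condition with a configurable bound
    def should_retry(improved: bool, tries: int, max_human_feedback_iterations: int = 1) -> bool:
        # allow one automatic try plus up to `max_human_feedback_iterations` retries
        return not improved and tries < 1 + max_human_feedback_iterations


    assert should_retry(improved=False, tries=1)       # first retry with human feedback
    assert not should_retry(improved=False, tries=2)   # bound reached, give up
    assert not should_retry(improved=True, tries=1)    # improvement ends the loop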