diff --git a/evoprompt/evolution/evolution.py b/evoprompt/evolution/evolution.py index 68b262bc67df9f22057a683aad8bd68e5b72bdf4..763a8f3130d8eda97c4a5477c2f173f47387a2c2 100644 --- a/evoprompt/evolution/evolution.py +++ b/evoprompt/evolution/evolution.py @@ -137,7 +137,12 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta): def update(self, *args, **kwargs): pass - def log_iteration(self, iteration: int, prompts: list[Prompt], num_failed_automatic_evolutions: int): + def log_iteration( + self, + iteration: int, + prompts: list[Prompt], + num_failed_automatic_evolutions: int, + ): if wandb.run is not None: best_prompt = max(prompts, key=lambda prompt: prompt.score) prompt_score_avg = sum(p.score for p in prompts) / len(prompts) @@ -248,7 +253,9 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta): # If a prompt is None, it means that the prompt was skipped if evolved_prompt is not None: - automatic_prompt_evolution_failed = {Judgement.BAD, Judgement.FAIL} & set(judgements) + automatic_prompt_evolution_failed = bool( + {Judgement.BAD, Judgement.FAIL} & set(judgements) + ) prompt_source = ( "corrected" # could also mean that user skipped the prompt if automatic_prompt_evolution_failed