From 44bb59b461c565d0d9e63a5b53fb0a04ddda3a7c Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Mon, 18 Nov 2024 11:07:08 +0100
Subject: [PATCH] Improve code readability

---
 evoprompt/evolution/evolution.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/evoprompt/evolution/evolution.py b/evoprompt/evolution/evolution.py
index 68b262b..763a8f3 100644
--- a/evoprompt/evolution/evolution.py
+++ b/evoprompt/evolution/evolution.py
@@ -137,7 +137,12 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
     def update(self, *args, **kwargs):
         pass
 
-    def log_iteration(self, iteration: int, prompts: list[Prompt], num_failed_automatic_evolutions: int):
+    def log_iteration(
+        self,
+        iteration: int,
+        prompts: list[Prompt],
+        num_failed_automatic_evolutions: int,
+    ):
         if wandb.run is not None:
             best_prompt = max(prompts, key=lambda prompt: prompt.score)
             prompt_score_avg = sum(p.score for p in prompts) / len(prompts)
@@ -248,7 +253,9 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
 
                     # If a prompt is None, it means that the prompt was skipped
                     if evolved_prompt is not None:
-                        automatic_prompt_evolution_failed = {Judgement.BAD, Judgement.FAIL} & set(judgements)
+                        automatic_prompt_evolution_failed = bool(
+                            {Judgement.BAD, Judgement.FAIL} & set(judgements)
+                        )
                         prompt_source = (
                             "corrected"  # could also mean that user skipped the prompt
                             if automatic_prompt_evolution_failed
-- 
GitLab