From 64a960e687f168d58b1f755d6738e9e145c85923 Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Wed, 23 Oct 2024 12:53:58 +0200
Subject: [PATCH] Skipping prompts will no longer result in logging old prompts

---
 evoprompt/evolution/evolution.py | 21 ++++++++++++++-------
 evoprompt/task/task.py           |  2 +-
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/evoprompt/evolution/evolution.py b/evoprompt/evolution/evolution.py
index 5204ef1..ac4d97c 100644
--- a/evoprompt/evolution/evolution.py
+++ b/evoprompt/evolution/evolution.py
@@ -214,7 +214,9 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
                         self.log_prompt(evolved_prompt, t, i)
 
                         new_evolutions.append(evolved_prompt)
-                        self.save_snapshot()
+                    else:
+                        new_evolutions.append(None)
+                    self.save_snapshot()
                 # Line 6: Update based on the evaluation scores
                 # Pt ← {Pt−1, p′i} and St ← {St−1, s′i}
                 new_population = self.update(new_evolutions, prompts_current_evolution)
@@ -337,6 +339,8 @@ class GeneticAlgorithm(EvolutionAlgorithm):
         retained_prompts: list[Prompt] = []
         min_retained_score = 0
         for prompt in prompts_current_evolution + new_evolutions:
+            if prompt is None:
+                continue
             if len(retained_prompts) < self.population_size:
                 retained_prompts.append(prompt)
                 min_retained_score = min(min_retained_score, prompt.score)
@@ -415,9 +419,9 @@ class DifferentialEvolution(EvolutionAlgorithm):
         )
 
         if judgement.skip:
-            # user asked to skip this prompt, for DE this means using the basic prompt
+            # user asked to skip this prompt
             return (
-                prompts_current_evolution[current_iteration].content,
+                None,
                 [judgement],
                 usage,
             )
@@ -428,9 +432,9 @@ class DifferentialEvolution(EvolutionAlgorithm):
         if evolved_prompt is None:
             logger.info(f"Could not extract prompt from response: {evolved_prompt}")
 
-            # no prompt was returned (e.g., evolved prompt could not be extracted), therefore, for DE, we use the basic prompt
+            # no prompt was returned (e.g., evolved prompt could not be extracted)
             return (
-                prompts_current_evolution[current_iteration].content,
+                None,
                 [judgement],
                 usage,
             )
@@ -453,7 +457,11 @@ class DifferentialEvolution(EvolutionAlgorithm):
         # for DE we keep the evolved prompt if it is better than the basic prompt, and use the basic prompt otherwise
         assert len(prompts_current_evolution) == len(new_evolutions)
         population = [
-            (new_prompt if new_prompt.score > current_prompt.score else current_prompt)
+            (
+                new_prompt
+                if new_prompt is not None and new_prompt.score > current_prompt.score
+                else current_prompt
+            )
             for current_prompt, new_prompt in zip(
                 prompts_current_evolution, new_evolutions
             )
@@ -553,7 +561,6 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
                 response,
             )
             # TODO use serialized messages as prompt or use previous evolution steps as history?
-            input(f"{len(evolutions_steps)}, \n{evolutions_steps}")
             judgement = self.judge_and_correct_step(
                 filled_prompt,
                 response,
diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index 90f128b..e322998 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -411,7 +411,7 @@ class Task(metaclass=ABCMeta):
         system_message, prompt_for_datum = self.build_prompt_input(
             datum, instruction, use_prediction_prefix=self.force_task_prediction_prefix
         )
-        logger.debug(f"Prompt for datum:\n{prompt_for_datum}")
+        logger.debug(f"==== Task Prediction ====\nSystem message:\n\t{system_message}\nPrompt for datum:\n\t{prompt_for_datum}\nHistory:\n\t{history}")
         response, _, _, usage = self.model.create_completion(
             system_message=system_message,
             messages=prompt_for_datum,
-- 
GitLab