From fe66ab5a950c6f83d688d6750e4b9b04245e4071 Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Tue, 3 Dec 2024 21:58:12 +0100
Subject: [PATCH] Log user interactions

---
 evoprompt/evolution/evolution.py | 89 +++++++++++++++++++++++++-------
 evoprompt/optimization.py        | 53 +++++++------------
 2 files changed, 89 insertions(+), 53 deletions(-)

diff --git a/evoprompt/evolution/evolution.py b/evoprompt/evolution/evolution.py
index 48c69b6..1ad2a69 100644
--- a/evoprompt/evolution/evolution.py
+++ b/evoprompt/evolution/evolution.py
@@ -96,6 +96,13 @@ class EvolutionAnnotationHandler:
             label=label,
             incorrect_label=incorrect_label,
         )
+        logger.info(
+            "Annotated sample (%d remaining): '%s' -> '%s' (incorrect: '%s')",
+            self.max_annotations - len(self),
+            prompt,
+            label,
+            incorrect_label,
+        )
 
     def get_annotations(self, evolution_step: int):
         if self.strategy == "simple":
@@ -257,29 +264,35 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
                 }
             )
 
-    def log_evolution_tries(
+    def log_evolution(
         self,
         *,
         generation: int,
+        update_step: int,
+        total_update_step: int,
         num_bad_judgements: int,
         num_failed_judgements: int,
         num_good_judgements: int,
         final_evolution_step: int,
         num_model_calls: int,
+        num_user_interactions: int,
     ):
         if wandb.run is not None:
             wandb.log(
                 {
                     "generation": generation,
+                    "update_step": update_step,
+                    "total_update_step": total_update_step,
                     "final_evolution_step": final_evolution_step,
                     "num_model_calls": num_model_calls,
                     "num_bad_judgements": num_bad_judgements,
                     "num_failed_judgements": num_failed_judgements,
                     "num_good_judgements": num_good_judgements,
+                    "num_user_interactions": num_user_interactions,
                 }
             )
 
-    def log_evolution_step_tries(
+    def log_evolution_step(
         self,
         *,
         generation: int,
@@ -288,6 +301,7 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
         num_bad_judgements: int,
         num_failed_judgements: int,
         num_model_calls: int,
+        num_user_interactions: int,
     ):
         if wandb.run is not None:
             wandb.log(
@@ -298,9 +312,26 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
                     "evolution_step_num_bad_judgements": num_bad_judgements,
                     "evolution_step_num_failed_judgements": num_failed_judgements,
                     "evolution_step_num_model_calls": num_model_calls,
+                    "evolution_step_num_user_interactions": num_user_interactions,
                 }
             )
 
+    # also logs generation, update_step and total_update_step for a prompt
+    def log_prompt(
+        self,
+        prompt: Prompt,
+        *,
+        generation: int,
+        update_step: int,
+        total_update_step: int,
+    ):
+        super().log_prompt(
+            prompt,
+            generation=generation,
+            update_step=update_step,
+            total_update_step=total_update_step,
+        )
+
     @weave.op()
     def run(self, num_iterations: int, debug: bool = False) -> None:
         # debug mode for quick run
@@ -537,17 +568,21 @@ class GeneticAlgorithm(EvolutionAlgorithm):
             ),
         )
 
-        self.log_evolution_step_tries(
+        self.log_evolution_step(
             generation=current_generation,
             evolution_step=1,
             did_succeed=judgement in (Judgement.GOOD, Judgement.NONE),
             num_bad_judgements=all_judgements.count(Judgement.BAD),
             num_failed_judgements=all_judgements.count(Judgement.FAIL),
             num_model_calls=len(all_judgements),
+            num_user_interactions=0,
         )
 
-        self.log_evolution_tries(
+        self.log_evolution(
             generation=current_generation,
+            update_step=current_iteration,
+            total_update_step=current_generation * self.population_size
+            + current_iteration,
             num_bad_judgements=all_judgements.count(Judgement.BAD),
             num_failed_judgements=all_judgements.count(Judgement.FAIL),
             num_good_judgements=all_judgements.count(Judgement.GOOD),
@@ -659,17 +694,21 @@ class DifferentialEvolution(EvolutionAlgorithm):
                 evolution_step=1,
             ),
         )
-        self.log_evolution_step_tries(
+        self.log_evolution_step(
             generation=current_generation,
             evolution_step=1,
             did_succeed=judgement in (Judgement.GOOD, Judgement.NONE),
             num_bad_judgements=all_judgements.count(Judgement.BAD),
             num_failed_judgements=all_judgements.count(Judgement.FAIL),
             num_model_calls=len(all_judgements),
+            num_user_interactions=0,
        )
 
-        self.log_evolution_tries(
+        self.log_evolution(
             generation=current_generation,
+            update_step=current_iteration,
+            total_update_step=current_generation * self.population_size
+            + current_iteration,
             num_bad_judgements=all_judgements.count(Judgement.BAD),
             num_failed_judgements=all_judgements.count(Judgement.FAIL),
             num_good_judgements=all_judgements.count(Judgement.GOOD),
@@ -768,6 +807,7 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
         judgements: list[Judgement] = []
         usage: ModelUsage = ModelUsage()
         evolution_judgements = []
+        num_user_interactions = 0
         for idx, (prompt, demo_messages) in enumerate(
             self._get_prompt_template(), start=1
         ):
@@ -813,13 +853,30 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
             )
             evolution_judgements += all_judgements
 
-            self.log_evolution_step_tries(
+            # If the response was modified, we add an annotation using the corrected response as the label
+            # TODO move to resolve_bad_response in _get_model_response (using callback)?
+            if resolution.corrected_response is not None:
+                self.annotation_handler.add_annotation(
+                    generation=current_generation,
+                    update_step=current_iteration,
+                    evolution_step=idx,
+                    prompt=prompt,
+                    label=resolution.corrected_response,
+                    incorrect_label=resolution.original_response,
+                )
+                user_interaction = True
+                num_user_interactions += 1
+            else:
+                user_interaction = False
+
+            self.log_evolution_step(
                 generation=current_generation,
                 evolution_step=idx,
                 did_succeed=judgement in (Judgement.GOOD, Judgement.NONE),
                 num_bad_judgements=all_judgements.count(Judgement.BAD),
                 num_failed_judgements=all_judgements.count(Judgement.FAIL),
                 num_model_calls=len(all_judgements),
+                num_user_interactions=int(user_interaction),
             )
 
             response = (
@@ -828,18 +885,6 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
                 corrected_response
                 if resolution.corrected_response is not None
                 else resolution.original_response
             )
 
-            # If the response was modified, we add an annotation using the corrected response as the label
-            # TODO move to resolve_bad_response in _get_model_response (using callback)?
-            if resolution.corrected_response is not None:
-                self.annotation_handler.add_annotation(
-                    generation=current_generation,
-                    update_step=current_iteration,
-                    evolution_step=idx,
-                    prompt=prompt,
-                    label=resolution.corrected_response,
-                    incorrect_label=resolution.original_response,
-                )
-
             logger.debug(
                 "Performed evolution (step %d) using DE-CoT:\n\tInputs: %s\n\tResponse: %s",
                 idx,
@@ -854,13 +899,17 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
             evolutions_steps.append(
                 self.evolution_model._get_assistant_message(response)
             )
-        self.log_evolution_tries(
+        self.log_evolution(
             generation=current_generation,
+            update_step=current_iteration,
+            total_update_step=current_generation * self.population_size
+            + current_iteration,
             num_bad_judgements=evolution_judgements.count(Judgement.BAD),
             num_failed_judgements=evolution_judgements.count(Judgement.FAIL),
             num_good_judgements=evolution_judgements.count(Judgement.GOOD),
             final_evolution_step=idx,
             num_model_calls=len(evolution_judgements),
+            num_user_interactions=num_user_interactions,
         )
 
         evolved_prompt = self.parse_response(response)
diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index 6b77587..593bfe8 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -190,13 +190,6 @@ class AnnotationHandler:
         self.annotated_samples[identifier].append(
             AnnotatedSample(prompt, label, incorrect_label)
         )
-        logger.info(
-            "Annotated sample (%d remaining): '%s' -> '%s' (incorrect: '%s')",
-            self.max_annotations - len(self),
-            prompt,
-            label,
-            incorrect_label,
-        )
 
     def get_annotations(self) -> list[AnnotatedSample]:
         if self.strategy == "simple":
@@ -297,33 +290,27 @@ class PromptOptimization:
     def log_prompt(
         self,
         prompt: Prompt,
-        *,
-        generation: int,
-        update_step: int,
-        total_update_step: int,
+        **kwargs,
     ):
         if wandb.run is not None:
-            wandb.log(
-                {
-                    "prompts": wandb.Html(
-                        tabulate(
-                            (
-                                (prompt.content, prompt.score, prompt.meta["gen"])
-                                for prompt in self.all_prompts.values()
-                            ),
-                            tablefmt="html",
-                            headers=["Prompt", "Score", "Gen"],
-                            showindex=True,
-                        )
-                    ),
-                    # having different metrics makes it difficult to plot the data, so we just log the score under the metric name and under a generic name
-                    f"validation/{self.task.metric_name}": prompt.score,
-                    "validation/score": prompt.score,
-                    "generation": generation,
-                    "update_step": update_step,
-                    "total_update_step": total_update_step,
-                },
-            )
+            dict_to_log = {
+                "prompts": wandb.Html(
+                    tabulate(
+                        (
+                            (prompt.content, prompt.score, prompt.meta["gen"])
+                            for prompt in self.all_prompts.values()
+                        ),
+                        tablefmt="html",
+                        headers=["Prompt", "Score", "Gen"],
+                        showindex=True,
+                    )
+                ),
+                # having different metrics makes it difficult to plot the data, so we just log the score under the metric name and under a generic name
+                f"validation/{self.task.metric_name}": prompt.score,
+                "validation/score": prompt.score,
+            }
+            dict_to_log.update(kwargs)
+            wandb.log(dict_to_log)
 
     def add_prompt(
         self, prompt: str, parents: tuple[Prompt] | None, meta: PromptMeta
@@ -507,7 +494,7 @@ class PromptOptimization:
         ), "Annotation Check Function required."
         if not fn_check_for_annotation():
             logger.info(
-                f"Prompt judged as bad, but annotation limit reached. Will not ask for correction."
+                f"Asked for correction, but annotation limit reached. Will not ask for correction."
             )
             return JudgeResolution(response, None, skip=False)
         # let user skip or correct the response in an interactive way
--
GitLab
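
The base-class change in this patch makes PromptOptimization.log_prompt accept **kwargs and merge them into the logged dictionary, while the EvolutionAlgorithm override forwards generation, update_step, and total_update_step (computed at the call sites as current_generation * population_size + current_iteration). Below is a minimal, self-contained sketch of that forwarding pattern; the class and field names are illustrative only (not the actual evoprompt API), and print() stands in for wandb.log() so the snippet runs without wandb.

# Minimal sketch of the **kwargs forwarding used by log_prompt in this patch.
# Class and field names are illustrative, not the actual evoprompt classes;
# print() stands in for wandb.log() so the snippet runs without wandb.


class BaseOptimization:
    def log_prompt(self, prompt: str, score: float, **kwargs) -> None:
        # Fixed fields are assembled first ...
        payload = {"prompt": prompt, "validation/score": score}
        # ... then caller-supplied step counters are merged in, so subclasses
        # can attach extra metadata without this method knowing about it.
        payload.update(kwargs)
        print(payload)  # stand-in for wandb.log(payload)


class EvolutionOptimization(BaseOptimization):
    def __init__(self, population_size: int) -> None:
        self.population_size = population_size

    def log_prompt(
        self, prompt: str, score: float, *, generation: int, update_step: int
    ) -> None:
        # total_update_step mirrors the call sites in the patch:
        # current_generation * population_size + current_iteration
        super().log_prompt(
            prompt,
            score,
            generation=generation,
            update_step=update_step,
            total_update_step=generation * self.population_size + update_step,
        )


if __name__ == "__main__":
    EvolutionOptimization(population_size=10).log_prompt(
        "Summarize the following text: {input}", 0.82, generation=2, update_step=3
    )

Keeping the base logger agnostic to the step counters means new metrics (such as num_user_interactions elsewhere in this patch) can be logged by callers without further changes to the shared logging code.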