From fe66ab5a950c6f83d688d6750e4b9b04245e4071 Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Tue, 3 Dec 2024 21:58:12 +0100
Subject: [PATCH] Log user interactions

---
 evoprompt/evolution/evolution.py | 89 +++++++++++++++++++++++++-------
 evoprompt/optimization.py        | 53 +++++++------------
 2 files changed, 89 insertions(+), 53 deletions(-)

diff --git a/evoprompt/evolution/evolution.py b/evoprompt/evolution/evolution.py
index 48c69b6..1ad2a69 100644
--- a/evoprompt/evolution/evolution.py
+++ b/evoprompt/evolution/evolution.py
@@ -96,6 +96,13 @@ class EvolutionAnnotationHandler:
             label=label,
             incorrect_label=incorrect_label,
         )
+        logger.info(
+            "Annotated sample (%d remaining): '%s' -> '%s' (incorrect: '%s')",
+            self.max_annotations - len(self),
+            prompt,
+            label,
+            incorrect_label,
+        )
 
     def get_annotations(self, evolution_step: int):
         if self.strategy == "simple":
@@ -257,29 +264,35 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
                 }
             )
 
-    def log_evolution_tries(
+    def log_evolution(
         self,
         *,
         generation: int,
+        update_step: int,
+        total_update_step: int,
         num_bad_judgements: int,
         num_failed_judgements: int,
         num_good_judgements: int,
         final_evolution_step: int,
         num_model_calls: int,
+        num_user_interactions: int,
     ):
         if wandb.run is not None:
             wandb.log(
                 {
                     "generation": generation,
+                    "update_step": update_step,
+                    "total_update_step": total_update_step,
                     "final_evolution_step": final_evolution_step,
                     "num_model_calls": num_model_calls,
                     "num_bad_judgements": num_bad_judgements,
                     "num_failed_judgements": num_failed_judgements,
                     "num_good_judgements": num_good_judgements,
+                    "num_user_interactions": num_user_interactions,
                 }
             )
 
-    def log_evolution_step_tries(
+    def log_evolution_step(
         self,
         *,
         generation: int,
@@ -288,6 +301,7 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
         num_bad_judgements: int,
         num_failed_judgements: int,
         num_model_calls: int,
+        num_user_interactions: int,
     ):
         if wandb.run is not None:
             wandb.log(
@@ -298,9 +312,26 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
                     "evolution_step_num_bad_judgements": num_bad_judgements,
                     "evolution_step_num_failed_judgements": num_failed_judgements,
                     "evolution_step_num_model_calls": num_model_calls,
+                    "evolution_step_num_user_interactions": num_user_interactions,
                 }
             )
 
+    # also logs generation, update_step and total_update_step for a prompt
+    def log_prompt(
+        self,
+        prompt: Prompt,
+        *,
+        generation: int,
+        update_step: int,
+        total_update_step: int,
+    ):
+        super().log_prompt(
+            prompt,
+            generation=generation,
+            update_step=update_step,
+            total_update_step=total_update_step,
+        )
+
     @weave.op()
     def run(self, num_iterations: int, debug: bool = False) -> None:
         # debug mode for quick run
@@ -537,17 +568,21 @@ class GeneticAlgorithm(EvolutionAlgorithm):
             ),
         )
 
-        self.log_evolution_step_tries(
+        self.log_evolution_step(
             generation=current_generation,
             evolution_step=1,
             did_succeed=judgement in (Judgement.GOOD, Judgement.NONE),
             num_bad_judgements=all_judgements.count(Judgement.BAD),
             num_failed_judgements=all_judgements.count(Judgement.FAIL),
             num_model_calls=len(all_judgements),
+            num_user_interactions=0,
         )
 
-        self.log_evolution_tries(
+        self.log_evolution(
             generation=current_generation,
+            update_step=current_iteration,
+            total_update_step=current_generation * self.population_size
+            + current_iteration,
             num_bad_judgements=all_judgements.count(Judgement.BAD),
             num_failed_judgements=all_judgements.count(Judgement.FAIL),
             num_good_judgements=all_judgements.count(Judgement.GOOD),
@@ -659,17 +694,21 @@ class DifferentialEvolution(EvolutionAlgorithm):
                 evolution_step=1,
             ),
         )
-        self.log_evolution_step_tries(
+        self.log_evolution_step(
             generation=current_generation,
             evolution_step=1,
             did_succeed=judgement in (Judgement.GOOD, Judgement.NONE),
             num_bad_judgements=all_judgements.count(Judgement.BAD),
             num_failed_judgements=all_judgements.count(Judgement.FAIL),
             num_model_calls=len(all_judgements),
+            num_user_interactions=0,
        )
 
-        self.log_evolution_tries(
+        self.log_evolution(
             generation=current_generation,
+            update_step=current_iteration,
+            total_update_step=current_generation * self.population_size
+            + current_iteration,
             num_bad_judgements=all_judgements.count(Judgement.BAD),
             num_failed_judgements=all_judgements.count(Judgement.FAIL),
             num_good_judgements=all_judgements.count(Judgement.GOOD),
@@ -768,6 +807,7 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
         judgements: list[Judgement] = []
         usage: ModelUsage = ModelUsage()
         evolution_judgements = []
+        num_user_interactions = 0
         for idx, (prompt, demo_messages) in enumerate(
             self._get_prompt_template(), start=1
         ):
@@ -813,13 +853,30 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
             )
             evolution_judgements += all_judgements
 
-            self.log_evolution_step_tries(
+            # If the response was modified, we add an annotation using the corrected response as the label
+            # TODO move to resolve_bad_response in _get_model_response (using callback)?
+            if resolution.corrected_response is not None:
+                self.annotation_handler.add_annotation(
+                    generation=current_generation,
+                    update_step=current_iteration,
+                    evolution_step=idx,
+                    prompt=prompt,
+                    label=resolution.corrected_response,
+                    incorrect_label=resolution.original_response,
+                )
+                user_interaction = True
+                num_user_interactions += 1
+            else:
+                user_interaction = False
+
+            self.log_evolution_step(
                 generation=current_generation,
                 evolution_step=idx,
                 did_succeed=judgement in (Judgement.GOOD, Judgement.NONE),
                 num_bad_judgements=all_judgements.count(Judgement.BAD),
                 num_failed_judgements=all_judgements.count(Judgement.FAIL),
                 num_model_calls=len(all_judgements),
+                num_user_interactions=int(user_interaction),
             )
 
             response = (
@@ -828,18 +885,6 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
                 corrected_response
                 if resolution.corrected_response is not None
                 else resolution.original_response
             )
 
-            # If the response was modified, we add an annotation using the corrected response as the label
-            # TODO move to resolve_bad_response in _get_model_response (using callback)?
-            if resolution.corrected_response is not None:
-                self.annotation_handler.add_annotation(
-                    generation=current_generation,
-                    update_step=current_iteration,
-                    evolution_step=idx,
-                    prompt=prompt,
-                    label=resolution.corrected_response,
-                    incorrect_label=resolution.original_response,
-                )
-
             logger.debug(
                 "Performed evolution (step %d) using DE-CoT:\n\tInputs: %s\n\tResponse: %s",
                 idx,
@@ -854,13 +899,17 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
             evolutions_steps.append(
                 self.evolution_model._get_assistant_message(response)
             )
-        self.log_evolution_tries(
+        self.log_evolution(
             generation=current_generation,
+            update_step=current_iteration,
+            total_update_step=current_generation * self.population_size
+            + current_iteration,
             num_bad_judgements=evolution_judgements.count(Judgement.BAD),
             num_failed_judgements=evolution_judgements.count(Judgement.FAIL),
             num_good_judgements=evolution_judgements.count(Judgement.GOOD),
             final_evolution_step=idx,
             num_model_calls=len(evolution_judgements),
+            num_user_interactions=num_user_interactions,
         )
 
         evolved_prompt = self.parse_response(response)
diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index 6b77587..593bfe8 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -190,13 +190,6 @@ class AnnotationHandler:
         self.annotated_samples[identifier].append(
             AnnotatedSample(prompt, label, incorrect_label)
         )
-        logger.info(
-            "Annotated sample (%d remaining): '%s' -> '%s' (incorrect: '%s')",
-            self.max_annotations - len(self),
-            prompt,
-            label,
-            incorrect_label,
-        )
 
     def get_annotations(self) -> list[AnnotatedSample]:
         if self.strategy == "simple":
@@ -297,33 +290,27 @@ class PromptOptimization:
     def log_prompt(
         self,
         prompt: Prompt,
-        *,
-        generation: int,
-        update_step: int,
-        total_update_step: int,
+        **kwargs,
     ):
         if wandb.run is not None:
-            wandb.log(
-                {
-                    "prompts": wandb.Html(
-                        tabulate(
-                            (
-                                (prompt.content, prompt.score, prompt.meta["gen"])
-                                for prompt in self.all_prompts.values()
-                            ),
-                            tablefmt="html",
-                            headers=["Prompt", "Score", "Gen"],
-                            showindex=True,
-                        )
-                    ),
-                    # having different metrics makes it difficult to plot the data, so we just log the score under the metric name and under a generic name
-                    f"validation/{self.task.metric_name}": prompt.score,
-                    "validation/score": prompt.score,
-                    "generation": generation,
-                    "update_step": update_step,
-                    "total_update_step": total_update_step,
-                },
-            )
+            dict_to_log = {
+                "prompts": wandb.Html(
+                    tabulate(
+                        (
+                            (prompt.content, prompt.score, prompt.meta["gen"])
+                            for prompt in self.all_prompts.values()
+                        ),
+                        tablefmt="html",
+                        headers=["Prompt", "Score", "Gen"],
+                        showindex=True,
+                    )
+                ),
+                # having different metrics makes it difficult to plot the data, so we just log the score under the metric name and under a generic name
+                f"validation/{self.task.metric_name}": prompt.score,
+                "validation/score": prompt.score,
+            }
+            dict_to_log.update(kwargs)
+            wandb.log(dict_to_log)
 
     def add_prompt(
         self, prompt: str, parents: tuple[Prompt] | None, meta: PromptMeta
@@ -507,7 +494,7 @@ class PromptOptimization:
         ), "Annotation Check Function required."
         if not fn_check_for_annotation():
             logger.info(
-                f"Prompt judged as bad, but annotation limit reached. Will not ask for correction."
+                f"Asked for correction, but annotation limit reached. Will not ask for correction."
             )
             return JudgeResolution(response, None, skip=False)
         # let user skip or correct the response in an interactive way
--
GitLab
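
The base-class change in this patch makes PromptOptimization.log_prompt accept **kwargs and merge them into the logged dictionary, while the EvolutionAlgorithm override forwards generation, update_step, and total_update_step (computed at the call sites as current_generation * population_size + current_iteration). Below is a minimal, self-contained sketch of that forwarding pattern; the class and field names are illustrative only (not the actual evoprompt API), and print() stands in for wandb.log() so the snippet runs without wandb.

# Minimal sketch of the **kwargs forwarding used by log_prompt in this patch.
# Class and field names are illustrative, not the actual evoprompt classes;
# print() stands in for wandb.log() so the snippet runs without wandb.


class BaseOptimization:
    def log_prompt(self, prompt: str, score: float, **kwargs) -> None:
        # Fixed fields are assembled first ...
        payload = {"prompt": prompt, "validation/score": score}
        # ... then caller-supplied step counters are merged in, so subclasses
        # can attach extra metadata without this method knowing about it.
        payload.update(kwargs)
        print(payload)  # stand-in for wandb.log(payload)


class EvolutionOptimization(BaseOptimization):
    def __init__(self, population_size: int) -> None:
        self.population_size = population_size

    def log_prompt(
        self, prompt: str, score: float, *, generation: int, update_step: int
    ) -> None:
        # total_update_step mirrors the call sites in the patch:
        # current_generation * population_size + current_iteration
        super().log_prompt(
            prompt,
            score,
            generation=generation,
            update_step=update_step,
            total_update_step=generation * self.population_size + update_step,
        )


if __name__ == "__main__":
    EvolutionOptimization(population_size=10).log_prompt(
        "Summarize the following text: {input}", 0.82, generation=2, update_step=3
    )

Keeping the base logger agnostic to the step counters means new metrics (such as num_user_interactions elsewhere in this patch) can be logged by callers without further changes to the shared logging code.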