diff --git a/evoprompt/evolution/evolution.py b/evoprompt/evolution/evolution.py index eb3114a5a5ffc089ac193970b514b052a0239e28..a5a7eaa60e0b92ca9117fff8c85e57e404c9c991 100644 --- a/evoprompt/evolution/evolution.py +++ b/evoprompt/evolution/evolution.py @@ -4,10 +4,10 @@ from abc import ABCMeta, abstractmethod from collections.abc import Iterable from typing import Any -import wandb import weave from tqdm import trange +import wandb from evoprompt.evolution.template_de import ( DE_DEMONSTRATION_DATA_CLS, DE_DEMONSTRATION_DATA_SIM, @@ -153,6 +153,23 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta): } ) + def log_model_usage( + self, evaluation_usage: ModelUsage, evolution_usage: ModelUsage + ): + if wandb.run is not None: + wandb.log( + { + "usage/evaluation/prompt_tokens": evaluation_usage.prompt_tokens, + "usage/evaluation/completion_tokens": evaluation_usage.completion_tokens, + "usage/evaluation/total_tokens": evaluation_usage.total_tokens, + "usage/evolution/prompt_tokens": evolution_usage.prompt_tokens, + "usage/evolution/completion_tokens": evolution_usage.completion_tokens, + "usage/evolution/total_tokens": evolution_usage.total_tokens, + "usage/total": evaluation_usage.total_tokens + + evolution_usage.total_tokens, + } + ) + @weave.op() def run(self, num_iterations: int, debug: bool = False) -> None: # debug mode for quick run @@ -244,6 +261,9 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta): test_performance, _, _ = self.task.evaluate_test(p.content) self.log_test_performance(test_performance) + self.log_model_usage( + self.total_evaluation_usage, self.total_evolution_usage + ) logger.info( "Best prompt on test set: %s %s",