From 96f6880f957c64978052a49f783ccb2421bc3150 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Fri, 25 Oct 2024 17:50:21 +0200
Subject: [PATCH] add usage metrics to wandb logs

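Log cumulative token usage of the evaluation and evolution models to
Weights & Biases. The new log_model_usage() helper records prompt,
completion and total token counts under the usage/evaluation/* and
usage/evolution/* keys, plus a combined usage/total, whenever a wandb
run is active. It is called with self.total_evaluation_usage and
self.total_evolution_usage right after the test-set performance of the
best prompt is logged.

A hypothetical sketch of how the new metrics could be read back through
the public wandb API (the run path below is a placeholder, not a value
from this repository):

    import wandb

    api = wandb.Api()
    run = api.run("entity/project/run_id")  # placeholder run path
    # history() returns the logged steps, as a pandas DataFrame by default
    history = run.history(keys=["usage/total", "usage/evaluation/total_tokens"])
    print(history.tail(1))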
---
 evoprompt/evolution/evolution.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/evoprompt/evolution/evolution.py b/evoprompt/evolution/evolution.py
index eb3114a..a5a7eaa 100644
--- a/evoprompt/evolution/evolution.py
+++ b/evoprompt/evolution/evolution.py
@@ -4,10 +4,10 @@ from abc import ABCMeta, abstractmethod
 from collections.abc import Iterable
 from typing import Any
 
-import wandb
 import weave
 from tqdm import trange
 
+import wandb
 from evoprompt.evolution.template_de import (
     DE_DEMONSTRATION_DATA_CLS,
     DE_DEMONSTRATION_DATA_SIM,
@@ -153,6 +153,23 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
                 }
             )
 
+    def log_model_usage(
+        self, evaluation_usage: ModelUsage, evolution_usage: ModelUsage
+    ):
+        if wandb.run is not None:
+            wandb.log(
+                {
+                    "usage/evaluation/prompt_tokens": evaluation_usage.prompt_tokens,
+                    "usage/evaluation/completion_tokens": evaluation_usage.completion_tokens,
+                    "usage/evaluation/total_tokens": evaluation_usage.total_tokens,
+                    "usage/evolution/prompt_tokens": evolution_usage.prompt_tokens,
+                    "usage/evolution/completion_tokens": evolution_usage.completion_tokens,
+                    "usage/evolution/total_tokens": evolution_usage.total_tokens,
+                    "usage/total": evaluation_usage.total_tokens
+                    + evolution_usage.total_tokens,
+                }
+            )
+
     @weave.op()
     def run(self, num_iterations: int, debug: bool = False) -> None:
         # debug mode for quick run
@@ -244,6 +261,9 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
             test_performance, _, _ = self.task.evaluate_test(p.content)
 
             self.log_test_performance(test_performance)
+            self.log_model_usage(
+                self.total_evaluation_usage, self.total_evolution_usage
+            )
 
             logger.info(
                 "Best prompt on test set: %s %s",
-- 
GitLab