From 59b65e5d12766c71d21828954eb4b7bab2711ec8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Fri, 25 Oct 2024 16:31:35 +0200
Subject: [PATCH] Calculate usage for HfChat models

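The HF text-generation pipeline does not report token usage. Override
Pipeline.run_single to count the input and generated tokens from the
tensors the pipeline already passes around, and return the resulting
ModelUsage alongside the outputs instead of an empty one.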
---
 evoprompt/models.py | 44 +++++++++++++++++++++++++++-----------------
 1 file changed, 27 insertions(+), 17 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 7c3e455..10d3379 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
 from argparse import ArgumentParser
 from collections.abc import Iterable
 from itertools import zip_longest
+from math import prod
 from pathlib import Path
 from typing import Any, Callable, ClassVar
 
@@ -500,19 +501,30 @@ class HfChat(ChatModel, LLMModel):
 
         class UsageGenerationPipeline(transformers.TextGenerationPipeline):
-            pass
-            # def run_single(
-            #     self, inputs, preprocess_params, forward_params, postprocess_params
-            # ):
-            #     model_inputs = self.preprocess(inputs, **preprocess_params)
-            #     print(model_inputs["input_ids"].shape)
-            #     model_outputs = self.forward(model_inputs, **forward_params)
-            #     print(type(model_outputs), model_outputs)
-            #     outputs = self.postprocess(model_outputs, **postprocess_params)
-            #     return outputs
-
-        # transformers.pipelines.SUPPORTED_TASKS["text-generation"][
-        #     "impl"
-        # ] = UsageGenerationPipeline
+
+            def run_single(
+                self, inputs, preprocess_params, forward_params, postprocess_params
+            ):
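+                # same flow as transformers.Pipeline.run_single, but also
+                # derive token usage from the tensors passed between stages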
+                model_inputs = self.preprocess(inputs, **preprocess_params)
+                model_outputs = self.forward(model_inputs, **forward_params)
+                outputs = self.postprocess(model_outputs, **postprocess_params)
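+                # input_ids has shape (batch_size, input_len); generated_sequence
+                # has shape (batch_size, num_return_sequences, sequence_len) and,
+                # for decoder-only models, still includes the prompt tokens, so
+                # output_usage counts those prompt tokens as well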
+                input_usage = prod(model_outputs["input_ids"].shape)
+                output_usage = prod(model_outputs["generated_sequence"].shape)
+                usage = ModelUsage(
+                    input_usage, output_usage, input_usage + output_usage
+                )
+                return outputs, usage
+
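+        # register the subclass as the implementation of the text-generation
+        # task so that transformers.pipeline() instantiates it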
+        transformers.pipelines.SUPPORTED_TASKS["text-generation"][
+            "impl"
+        ] = UsageGenerationPipeline
 
         self._model_name = model
 
@@ -588,14 +593,13 @@ class HfChat(ChatModel, LLMModel):
         # input(
         #     f"The input for the model will look like this:\n'{self.pipeline.tokenizer.apply_chat_template(model_call_kwargs["text_inputs"], tokenize=False, add_generation_prompt=True)}'"
         # )
-        response = self._call_model(
+        response, usage = self._call_model(
             self.pipeline,
             use_cache=use_cache,
             **model_call_kwargs,
         )
         response_text = response[0]["generated_text"][-1]["content"]
-        # no usage supported by HF pipeline; TODO manually compute usage?
-        usage = ModelUsage()
+
         return response_text, usage
 
     @property
-- 
GitLab