From 59b65e5d12766c71d21828954eb4b7bab2711ec8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Fri, 25 Oct 2024 16:31:35 +0200
Subject: [PATCH] calculate usage for hfchat models

---
 evoprompt/models.py | 36 ++++++++++++++++++++----------------
 1 file changed, 20 insertions(+), 16 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 7c3e455..10d3379 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
 from argparse import ArgumentParser
 from collections.abc import Iterable
 from itertools import zip_longest
+from math import prod
 from pathlib import Path
 from typing import Any, Callable, ClassVar
 
@@ -500,19 +501,23 @@ class HfChat(ChatModel, LLMModel):
 
         class UsageGenerationPipeline(transformers.TextGenerationPipeline):
             pass
-        # def run_single(
-        #     self, inputs, preprocess_params, forward_params, postprocess_params
-        # ):
-        #     model_inputs = self.preprocess(inputs, **preprocess_params)
-        #     print(model_inputs["input_ids"].shape)
-        #     model_outputs = self.forward(model_inputs, **forward_params)
-        #     print(type(model_outputs), model_outputs)
-        #     outputs = self.postprocess(model_outputs, **postprocess_params)
-        #     return outputs
-
-        # transformers.pipelines.SUPPORTED_TASKS["text-generation"][
-        #     "impl"
-        # ] = UsageGenerationPipeline
+
+            def run_single(
+                self, inputs, preprocess_params, forward_params, postprocess_params
+            ):
+                model_inputs = self.preprocess(inputs, **preprocess_params)
+                model_outputs = self.forward(model_inputs, **forward_params)
+                outputs = self.postprocess(model_outputs, **postprocess_params)
+                input_usage = prod(model_outputs["input_ids"].shape)
+                output_usage = prod(model_outputs["generated_sequence"].shape)
+                usage = ModelUsage(
+                    input_usage, output_usage, input_usage + output_usage
+                )
+                return outputs, usage
+
+        transformers.pipelines.SUPPORTED_TASKS["text-generation"][
+            "impl"
+        ] = UsageGenerationPipeline
 
         self._model_name = model
 
@@ -588,14 +593,13 @@ class HfChat(ChatModel, LLMModel):
         # input(
         #     f"The input for the model will look like this:\n'{self.pipeline.tokenizer.apply_chat_template(model_call_kwargs["text_inputs"], tokenize=False, add_generation_prompt=True)}'"
         # )
-        response = self._call_model(
+        response, usage = self._call_model(
             self.pipeline,
             use_cache=use_cache,
             **model_call_kwargs,
         )
         response_text = response[0]["generated_text"][-1]["content"]
-        # no usage supported by HF pipeline; TODO manually compute usage?
-        usage = ModelUsage()
+
         return response_text, usage
 
     @property
-- 
GitLab
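Reviewer note on the usage computation above: `prod` over a tensor's shape counts every position, including any padding, and for decoder-only models the `generated_sequence` that Hugging Face's text-generation pipeline carries out of its forward pass contains the prompt tokens as well as the newly generated ones, so `output_usage` as computed in the patch likely includes the input tokens (and the total then counts the prompt twice). Below is a minimal standalone sketch of the arithmetic with hypothetical tensor shapes; the three values correspond, as far as can be inferred from the call site in the patch, to the prompt/completion/total arguments of `ModelUsage`:

    from math import prod

    import torch

    # Hypothetical tensors mirroring the text-generation pipeline's
    # forward output (shapes are illustrative, not taken from the repo):
    #   input_ids:          (batch, prompt_len)
    #   generated_sequence: (batch, num_return_sequences, prompt_len + new_tokens)
    input_ids = torch.zeros(1, 12, dtype=torch.long)
    generated_sequence = torch.zeros(1, 1, 30, dtype=torch.long)

    input_usage = prod(input_ids.shape)  # 12 prompt tokens
    # Subtracting the prompt length avoids counting it again in the completion:
    output_usage = prod(generated_sequence.shape) - input_usage  # 18 new tokens
    total_usage = input_usage + output_usage  # 30 tokens overall
    print(input_usage, output_usage, total_usage)

If the double counting is intended (e.g., to mirror a provider that bills the echoed prompt), a short comment in `run_single` saying so would make the choice explicit.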