Skip to content
Snippets Groups Projects
Commit 59b65e5d authored by Grießhaber Daniel's avatar Grießhaber Daniel :squid:
Browse files

calculate usage for hfchat models

parent be7f82c0
No related branches found
No related tags found
1 merge request!9Hf usage
......@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
from argparse import ArgumentParser
from collections.abc import Iterable
from itertools import zip_longest
from math import prod
from pathlib import Path
from typing import Any, Callable, ClassVar
......@@ -500,19 +501,23 @@ class HfChat(ChatModel, LLMModel):
class UsageGenerationPipeline(transformers.TextGenerationPipeline):
pass
# def run_single(
# self, inputs, preprocess_params, forward_params, postprocess_params
# ):
# model_inputs = self.preprocess(inputs, **preprocess_params)
# print(model_inputs["input_ids"].shape)
# model_outputs = self.forward(model_inputs, **forward_params)
# print(type(model_outputs), model_outputs)
# outputs = self.postprocess(model_outputs, **postprocess_params)
# return outputs
# transformers.pipelines.SUPPORTED_TASKS["text-generation"][
# "impl"
# ] = UsageGenerationPipeline
def run_single(
    self, inputs, preprocess_params, forward_params, postprocess_params
):
    """Run one generation through the pipeline and also compute token usage.

    Overrides ``transformers.TextGenerationPipeline.run_single`` so that,
    in addition to the normal postprocessed outputs, a ``ModelUsage`` is
    returned (input tokens, output tokens, total) computed from tensor
    shapes — the HF pipeline itself reports no usage.

    Returns a ``(outputs, usage)`` tuple instead of the base class's bare
    ``outputs`` — callers of this pipeline must unpack both values.
    """
    model_inputs = self.preprocess(inputs, **preprocess_params)
    model_outputs = self.forward(model_inputs, **forward_params)
    outputs = self.postprocess(model_outputs, **postprocess_params)
    # Token counts are approximated as the element count (prod of the
    # shape) of the id tensors, i.e. batch size * sequence length.
    # NOTE(review): input usage is read from model_outputs, not
    # model_inputs — presumably forward() passes "input_ids" through
    # unchanged; confirm against the installed transformers version.
    input_usage = prod(model_outputs["input_ids"].shape)
    # NOTE(review): "generated_sequence" normally includes the prompt
    # tokens, so output_usage (and therefore the total below) likely
    # double-counts the input tokens — verify and subtract if so.
    output_usage = prod(model_outputs["generated_sequence"].shape)
    # Assumes ModelUsage(prompt_tokens, completion_tokens, total_tokens)
    # positional order — TODO confirm against the ModelUsage definition.
    usage = ModelUsage(
        input_usage, output_usage, input_usage + output_usage
    )
    return outputs, usage
transformers.pipelines.SUPPORTED_TASKS["text-generation"][
"impl"
] = UsageGenerationPipeline
self._model_name = model
......@@ -588,14 +593,13 @@ class HfChat(ChatModel, LLMModel):
# input(
# f"The input for the model will look like this:\n'{self.pipeline.tokenizer.apply_chat_template(model_call_kwargs["text_inputs"], tokenize=False, add_generation_prompt=True)}'"
# )
response = self._call_model(
response, usage = self._call_model(
self.pipeline,
use_cache=use_cache,
**model_call_kwargs,
)
response_text = response[0]["generated_text"][-1]["content"]
# no usage supported by HF pipeline; TODO manually compute usage?
usage = ModelUsage()
return response_text, usage
@property
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment