From be7f82c0158e726bb93a4d1a4cbdbad2f7cf83dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de> Date: Fri, 25 Oct 2024 16:13:14 +0200 Subject: [PATCH] use corrext wandb entity --- evoprompt/models.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/evoprompt/models.py b/evoprompt/models.py index 248bf94..7c3e455 100644 --- a/evoprompt/models.py +++ b/evoprompt/models.py @@ -498,6 +498,22 @@ class HfChat(ChatModel, LLMModel): import torch import transformers + class UsageGenerationPipeline(transformers.TextGenerationPipeline): + pass + # def run_single( + # self, inputs, preprocess_params, forward_params, postprocess_params + # ): + # model_inputs = self.preprocess(inputs, **preprocess_params) + # print(model_inputs["input_ids"].shape) + # model_outputs = self.forward(model_inputs, **forward_params) + # print(type(model_outputs), model_outputs) + # outputs = self.postprocess(model_outputs, **postprocess_params) + # return outputs + + # transformers.pipelines.SUPPORTED_TASKS["text-generation"][ + # "impl" + # ] = UsageGenerationPipeline + self._model_name = model # we collect all arguments to make sure they are passed to the super constructor -- GitLab