From be7f82c0158e726bb93a4d1a4cbdbad2f7cf83dd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Fri, 25 Oct 2024 16:13:14 +0200
Subject: [PATCH] use corrext wandb entity

---
 evoprompt/models.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 248bf94..7c3e455 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -498,6 +498,22 @@ class HfChat(ChatModel, LLMModel):
         import torch
         import transformers
 
+        class UsageGenerationPipeline(transformers.TextGenerationPipeline):
+            pass
+            # def run_single(
+            #     self, inputs, preprocess_params, forward_params, postprocess_params
+            # ):
+            #     model_inputs = self.preprocess(inputs, **preprocess_params)
+            #     print(model_inputs["input_ids"].shape)
+            #     model_outputs = self.forward(model_inputs, **forward_params)
+            #     print(type(model_outputs), model_outputs)
+            #     outputs = self.postprocess(model_outputs, **postprocess_params)
+            #     return outputs
+
+        # transformers.pipelines.SUPPORTED_TASKS["text-generation"][
+        #     "impl"
+        # ] = UsageGenerationPipeline
+
         self._model_name = model
 
         # we collect all arguments to make sure they are passed to the super constructor
-- 
GitLab