From d82b104a6e46ba7f8653721aa4b2a2ec28a059c3 Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Tue, 1 Oct 2024 18:19:56 +0200
Subject: [PATCH] Get rid of warning during generation with HF pipeline

---
 evoprompt/models.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 41b88a8..34eb729 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -409,7 +409,6 @@ class LlamaChat(ChatModel, Llama):
             **model_call_kwargs,
         )
         response_text = response["choices"][0]["message"]["content"]
-        # input(response_text)
 
         usage = ModelUsage(**response["usage"])
         return response_text, usage
@@ -452,13 +451,15 @@ class HfChat(ChatModel, LLMModel):
             **model_kwargs,
         )
         # Setting the pad token to the eos token to avoid stdout prints
-        # TODO sometimes there are multiple eos tokens, how to handle this?
-        if not isinstance(
+        # If there are multiple eos tokens, use the first one (as the transformers library does)
+        if isinstance(
             self.pipeline.model.generation_config.eos_token_id, (list, tuple)
         ):
-            self.pipeline.model.generation_config.pad_token_id = (
-                self.pipeline.model.generation_config.eos_token_id
-            )
+            eos_token_id = self.pipeline.model.generation_config.eos_token_id[0]
+        else:
+            eos_token_id = self.pipeline.model.generation_config.eos_token_id
+
+        self.pipeline.model.generation_config.pad_token_id = eos_token_id
 
     def _create_completion(
         self,
-- 
GitLab
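
For reviewers, a minimal standalone sketch of the same pad-token handling outside evoprompt/models.py. The pipeline task and model name below are illustrative assumptions, not taken from this patch; the eos/pad logic mirrors the added hunk.

    from transformers import pipeline

    # Illustrative checkpoint; any causal-LM chat model behaves the same way.
    chat = pipeline("text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct")

    eos_token_id = chat.model.generation_config.eos_token_id
    # eos_token_id can be a single id or a list of ids (some models define
    # several EOS tokens); take the first entry when it is a list, as the patch does.
    if isinstance(eos_token_id, (list, tuple)):
        eos_token_id = eos_token_id[0]

    # Setting pad_token_id explicitly avoids the pad_token_id/eos_token_id
    # message that transformers prints during open-ended generation.
    chat.model.generation_config.pad_token_id = eos_token_id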