From d82b104a6e46ba7f8653721aa4b2a2ec28a059c3 Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Tue, 1 Oct 2024 18:19:56 +0200
Subject: [PATCH] Get rid of warning during generation with HF pipeline

---
 evoprompt/models.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 41b88a8..34eb729 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -409,7 +409,6 @@ class LlamaChat(ChatModel, Llama):
             **model_call_kwargs,
         )
         response_text = response["choices"][0]["message"]["content"]
-        # input(response_text)
         usage = ModelUsage(**response["usage"])
         return response_text, usage
 
@@ -452,13 +451,15 @@ class HfChat(ChatModel, LLMModel):
             **model_kwargs,
         )
         # Setting the pad token to the eos token to avoid stdout prints
-        # TODO sometimes there are multiple eos tokens, how to handle this?
-        if not isinstance(
+        # if there are multiple eos tokens, we use the first one (similarly to how it is done in the TF library)
+        if isinstance(
             self.pipeline.model.generation_config.eos_token_id, (list, tuple)
         ):
-            self.pipeline.model.generation_config.pad_token_id = (
-                self.pipeline.model.generation_config.eos_token_id
-            )
+            eos_token_id = self.pipeline.model.generation_config.eos_token_id[0]
+        else:
+            eos_token_id = self.pipeline.model.generation_config.eos_token_id
+
+        self.pipeline.model.generation_config.pad_token_id = eos_token_id
 
     def _create_completion(
         self,
--
GitLab
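
Context for the change (not part of the patch): when pad_token_id is unset, transformers prints a "Setting `pad_token_id` to `eos_token_id`..." message on every generate call. The patch silences this by copying the EOS token id into pad_token_id up front, taking the first entry when a model defines several EOS tokens. Below is a minimal standalone sketch of the same idea against a plain transformers text-generation pipeline; the model name is illustrative and not taken from the repository.

    from transformers import pipeline

    # Illustrative model name; any instruct model with an EOS token works.
    pipe = pipeline("text-generation", model="HuggingFaceTB/SmolLM-135M-Instruct")

    eos_token_id = pipe.model.generation_config.eos_token_id
    # Some newer models define a list of EOS tokens; like the patch, use the first one.
    if isinstance(eos_token_id, (list, tuple)):
        eos_token_id = eos_token_id[0]
    pipe.model.generation_config.pad_token_id = eos_token_id

    # Generation now runs without the pad_token_id message being printed.
    print(pipe("Hello, world!", max_new_tokens=8)[0]["generated_text"])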