diff --git a/evoprompt/models.py b/evoprompt/models.py
index 41b88a8e7d078c46e17a6068588db4fb5e911318..34eb729fee3fa07464e007a872320c4718157520 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -409,7 +409,6 @@ class LlamaChat(ChatModel, Llama):
             **model_call_kwargs,
         )
         response_text = response["choices"][0]["message"]["content"]
-        # input(response_text)
         usage = ModelUsage(**response["usage"])

         return response_text, usage
@@ -452,13 +451,15 @@ class HfChat(ChatModel, LLMModel):
             **model_kwargs,
         )
         # Setting the pad token to the eos token to avoid stdout prints
-        # TODO sometimes there are multiple eos tokens, how to handle this?
-        if not isinstance(
+        # if there are multiple eos tokens, we use the first one (similarly to how it is done in the TF library)
+        if isinstance(
             self.pipeline.model.generation_config.eos_token_id, (list, tuple)
         ):
-            self.pipeline.model.generation_config.pad_token_id = (
-                self.pipeline.model.generation_config.eos_token_id
-            )
+            eos_token_id = self.pipeline.model.generation_config.eos_token_id[0]
+        else:
+            eos_token_id = self.pipeline.model.generation_config.eos_token_id
+
+        self.pipeline.model.generation_config.pad_token_id = eos_token_id

     def _create_completion(
         self,
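
For context, the second hunk normalizes a possibly list-valued `eos_token_id` to a single id before assigning it to `pad_token_id`. Below is a minimal, self-contained sketch of that pattern outside the class, assuming the `transformers` library; the checkpoint name "HuggingFaceTB/SmolLM2-135M-Instruct" is only an example, any chat model with one or more eos tokens exercises the same path.

# Sketch only, not part of the patch.
from transformers import pipeline

pipe = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-135M-Instruct")

# generation_config.eos_token_id may be a single int or a list of ints
# (newer chat models often declare several end-of-turn tokens).
eos_token_id = pipe.model.generation_config.eos_token_id
if isinstance(eos_token_id, (list, tuple)):
    eos_token_id = eos_token_id[0]  # use the first eos token, as in the patch above

# pad_token_id must be a single int; setting it up front suppresses the
# "Setting `pad_token_id` to `eos_token_id`" message printed during generation.
pipe.model.generation_config.pad_token_id = eos_token_id

print(pipe("Hello,", max_new_tokens=5)[0]["generated_text"])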