Commit d82b104a authored by Max Kimmich

Get rid of warning during generation with HF pipeline

parent 3faa07e9
Merge request !7: Refactor tasks and models and fix format for various models
@@ -409,7 +409,6 @@ class LlamaChat(ChatModel, Llama):
             **model_call_kwargs,
         )
         response_text = response["choices"][0]["message"]["content"]
-        # input(response_text)
         usage = ModelUsage(**response["usage"])
         return response_text, usage
@@ -452,13 +451,15 @@ class HfChat(ChatModel, LLMModel):
             **model_kwargs,
         )
         # Setting the pad token to the eos token to avoid stdout prints
-        # TODO sometimes there are multiple eos tokens, how to handle this?
-        if not isinstance(
+        # if there are multiple eos tokens, we use the first one (similarly to how it is done in the TF library)
+        if isinstance(
             self.pipeline.model.generation_config.eos_token_id, (list, tuple)
         ):
-            self.pipeline.model.generation_config.pad_token_id = (
-                self.pipeline.model.generation_config.eos_token_id
-            )
+            eos_token_id = self.pipeline.model.generation_config.eos_token_id[0]
+        else:
+            eos_token_id = self.pipeline.model.generation_config.eos_token_id
+        self.pipeline.model.generation_config.pad_token_id = eos_token_id
 
     def _create_completion(
         self,
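For context, the change above sets the generation config's pad_token_id to an EOS token, which suppresses the "Setting `pad_token_id` to `eos_token_id`" message that Hugging Face transformers otherwise prints on every generation call; when the config declares several EOS tokens, the first one is taken. A minimal standalone sketch of the same idea, outside this repository (the model name "gpt2" is just an example choice):

from transformers import pipeline

# Any causal LM works here; gpt2 is a small model that normally
# triggers the pad_token_id warning because it has no pad token.
pipe = pipeline("text-generation", model="gpt2")

eos_token_id = pipe.model.generation_config.eos_token_id
if isinstance(eos_token_id, (list, tuple)):
    # Some models declare several EOS tokens; fall back to the first,
    # mirroring the commit above.
    eos_token_id = eos_token_id[0]
pipe.model.generation_config.pad_token_id = eos_token_id

# Generation now proceeds without the pad_token_id message on stdout.
print(pipe("Hello, world", max_new_tokens=20)[0]["generated_text"])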