Skip to content
Snippets Groups Projects
Commit d82b104a authored by Max Kimmich's avatar Max Kimmich
Browse files

Get rid of warning during generation with HF pipeline

parent 3faa07e9
No related branches found
No related tags found
1 merge request: !7 "Refactor tasks and models and fix format for various models"
......@@ -409,7 +409,6 @@ class LlamaChat(ChatModel, Llama):
**model_call_kwargs,
)
response_text = response["choices"][0]["message"]["content"]
# input(response_text)
usage = ModelUsage(**response["usage"])
return response_text, usage
......@@ -452,13 +451,15 @@ class HfChat(ChatModel, LLMModel):
**model_kwargs,
)
# Setting the pad token to the eos token to avoid stdout prints
# TODO sometimes there are multiple eos tokens, how to handle this?
if not isinstance(
# if there are multiple eos tokens, we use the first one (similarly to how it is done in the TF library)
if isinstance(
self.pipeline.model.generation_config.eos_token_id, (list, tuple)
):
self.pipeline.model.generation_config.pad_token_id = (
self.pipeline.model.generation_config.eos_token_id
)
eos_token_id = self.pipeline.model.generation_config.eos_token_id[0]
else:
eos_token_id = self.pipeline.model.generation_config.eos_token_id
self.pipeline.model.generation_config.pad_token_id = eos_token_id
def _create_completion(
self,
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment