From cbade5a01045300d7d017a06bb3fd7c818b8e7ee Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Tue, 1 Oct 2024 18:18:59 +0200
Subject: [PATCH] Fix strange behavior in AlpacaHfChat models

---
 evoprompt/models.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 2dabae9..2fa83f6 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -468,6 +468,7 @@ class HfChat(ChatModel, LLMModel):
         stop: str | None,
         max_tokens: int | None,
         enforce_randomness: bool,
+        **kwargs,
     ):
         # setup kwargs for model call
         model_call_kwargs = {
@@ -512,6 +513,26 @@ class AlpacaHfChat(HfChat):
         # chat template for Alpaca adapted from https://huggingface.co/Vezora/Mistral-22B-v0.1/blob/c15d70465e2fc46c3c4d7fec8fb62f533d4ef09b/tokenizer_config.json#L30
         self.pipeline.tokenizer.chat_template = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ system_message + '\\n\\n' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '### Instruction:\\n' + message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response:\\n' + message['content'].strip()}}{% endif %}{% endfor %}"

+    def _create_completion(
+        self,
+        messages: list[dict[str, str]],
+        *,
+        use_cache: bool,
+        stop: str | None,
+        max_tokens: int | None,
+        enforce_randomness: bool,
+        **kwargs,
+    ):
+        # for some reason adding an empty assistant message yields different generations than adding it manually in the chat template
+        return super()._create_completion(
+            messages + [self._get_assistant_message("")],
+            use_cache=use_cache,
+            stop=stop,
+            max_tokens=max_tokens,
+            enforce_randomness=enforce_randomness,
+            **kwargs,
+        )
+
     def _get_input_prefix(self):
         return "### Input:\n"
--
GitLab
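
For context on the fix: below is a minimal plain-Python sketch (not part of the patch; the helper name render_alpaca is hypothetical) that mimics the Jinja chat template set in AlpacaHfChat.__init__. It shows what appending an empty assistant message, as the new _create_completion override does, changes at the prompt level: the rendered prompt then ends with the "### Response:\n" header that cues an Alpaca-style model to generate its answer.

def render_alpaca(messages: list[dict[str, str]]) -> str:
    # Plain-Python rendition of the Jinja template from the patch above
    # (the role-alternation check is omitted for brevity).
    prompt = ""
    if messages and messages[0]["role"] == "system":
        prompt += messages[0]["content"] + "\n\n"
        messages = messages[1:]
    for message in messages:
        if message["role"] == "user":
            prompt += "### Instruction:\n" + message["content"].strip() + "\n\n"
        elif message["role"] == "assistant":
            prompt += "### Response:\n" + message["content"].strip()
    return prompt


messages = [{"role": "user", "content": "Summarize the text."}]

# Without a trailing assistant message the prompt ends after the
# instruction block, giving the model no "### Response:" cue:
print(repr(render_alpaca(messages)))
# '### Instruction:\nSummarize the text.\n\n'

# With the empty assistant message appended, as _create_completion now
# does, the prompt ends with the response header:
print(repr(render_alpaca(messages + [{"role": "assistant", "content": ""}])))
# '### Instruction:\nSummarize the text.\n\n### Response:\n'

Note that the sketch only shows this prompt-level effect; why appending the empty message yields different generations than baking the same "### Response:\n" suffix into the template itself, as the in-code comment observes, is left open by the patch.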