From cbade5a01045300d7d017a06bb3fd7c818b8e7ee Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Tue, 1 Oct 2024 18:18:59 +0200
Subject: [PATCH] Fix inconsistent generations in AlpacaHfChat models

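The Alpaca chat template used here only renders the trailing
"### Response:\n" header when an assistant turn is present in the
conversation, so AlpacaHfChat now appends an empty assistant message
before delegating to HfChat._create_completion. For reasons that remain
unclear, this yields different generations than emitting the header
unconditionally from within the chat template itself. The base
HfChat._create_completion additionally accepts **kwargs so the override
can forward extra arguments.

As a rough sketch (with hypothetical message content), the prompt
rendered for the model now ends with the open response header:

    ### Instruction:
    Summarize the following text.

    ### Response: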
---
 evoprompt/models.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 2dabae9..2fa83f6 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -468,6 +468,7 @@ class HfChat(ChatModel, LLMModel):
         stop: str | None,
         max_tokens: int | None,
         enforce_randomness: bool,
+        **kwargs,
     ):
         # setup kwargs for model call
         model_call_kwargs = {
@@ -512,6 +513,26 @@ class AlpacaHfChat(HfChat):
         # chat template for Alpaca adapted from https://huggingface.co/Vezora/Mistral-22B-v0.1/blob/c15d70465e2fc46c3c4d7fec8fb62f533d4ef09b/tokenizer_config.json#L30
         self.pipeline.tokenizer.chat_template = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ system_message + '\\n\\n' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '### Instruction:\\n' + message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response:\\n'  + message['content'].strip()}}{% endif %}{% endfor %}"
 
+    def _create_completion(
+        self,
+        messages: list[dict[str, str]],
+        *,
+        use_cache: bool,
+        stop: str | None,
+        max_tokens: int | None,
+        enforce_randomness: bool,
+        **kwargs,
+    ):
+        # For some reason, appending an empty assistant message here yields different generations than emitting the trailing response header directly in the chat template.
+        return super()._create_completion(
+            messages + [self._get_assistant_message("")],
+            use_cache=use_cache,
+            stop=stop,
+            max_tokens=max_tokens,
+            enforce_randomness=enforce_randomness,
+            **kwargs,
+        )
+
     def _get_input_prefix(self):
         return "### Input:\n"
 
-- 
GitLab