From 537d12d9f8b15334be24fe55c82e2f3a44a4c519 Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Fri, 4 Oct 2024 21:08:57 +0200
Subject: [PATCH] Use chat-style messages for demonstration samples, respecting
 the model's chat format

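Demonstration samples are now encoded as chat messages (user/assistant turns)
via `LLMModel.build_demonstration_data` and passed to `create_completion` as
`history`, instead of being concatenated into a single prompt string. The
system message and the per-sample user message come from `build_input_data`,
which models such as `AlpacaHfChat` override to match their fine-tuning format.

As a rough illustration (the instruction, input, and label below are
placeholders, not values taken from the code), the message list sent to a
default chat model for one demonstration sample plus the current datum now
looks like this:

    # sketch of the resulting messages for a default chat model
    messages = [
        {"role": "system", "content": "<instruction>"},            # prepended by create_completion
        {"role": "user", "content": "<demonstration input>"},      # from build_demonstration_data
        {"role": "assistant", "content": "<demonstration label>"},
        {"role": "user", "content": "<input for the current datum>"},
    ]

For `AlpacaHfChat`, the instruction is instead embedded in each user message
together with the "### Input:" prefix, and the fixed Alpaca system message is
used, so the chat template reproduces the original Alpaca fine-tuning layout.
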
---
 evoprompt/models.py    | 152 +++++++++++++++++++++++++++--------------
 evoprompt/task/task.py |  75 ++++++++++++--------
 evoprompt/utils.py     |   4 +-
 main.py                |   5 +-
 4 files changed, 155 insertions(+), 81 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 8653c0d..4012ce4 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -1,8 +1,11 @@
+from collections.abc import Iterable
 import hashlib
 import inspect
 import json
 import logging
 import random
+import sqlite3
+import time
 import warnings
 from abc import ABC, abstractmethod
 from argparse import ArgumentParser
@@ -23,7 +26,8 @@ logger = logging.getLogger(__name__)
 logging.captureWarnings(True)
 warnings.simplefilter("once")
 
-ChatMessages = list[dict[str, str]]
+ChatMessage = dict[str, str]
+ChatMessages = list[ChatMessage]
 
 
 class LLMModel(ABC):
@@ -57,17 +61,29 @@ class LLMModel(ABC):
             LLMModel.loaded_models[cls] = (model, key)
         return model
 
-    def __init__(self, ignore_cache_kwargs: list[str] | None = None, **kwargs):
+    def __init__(
+        self, cache_kwargs: dict, ignore_cache_kwargs: list[str] | None = None, **kwargs
+    ):
         self.usage = ModelUsage()
 
         # store kwargs for caching
-        self.kwargs = kwargs.copy()
-        self.kwargs["ignore_cache_kwargs"] = ignore_cache_kwargs
+        self.cache_kwargs = cache_kwargs.copy()
+        self.cache_kwargs["ignore_cache_kwargs"] = ignore_cache_kwargs
+
+        self.max_tokens = kwargs.get("max_tokens", None)
 
         # set up caching for model calls
         self._call_model_cached = None
-        if not self.kwargs.get("disable_cache", False):
-            cache = Cache(Path(".cache_dir", self.model_cache_key))
+        if kwargs.get("disable_cache", False):
+            logger.info("Caching is disabled")
+        else:
+            while True:
+                try:
+                    cache = Cache(Path(".cache_dir", self.model_cache_key))
+                    break
+                except sqlite3.OperationalError:
+                    logger.warning("Failed to open cache, retrying...", exc_info=True)
+                    time.sleep(random.random() * 5)
 
             @cache.memoize(typed=True, ignore=[0, "func"])
             def _call_function(func, *args, **kwargs):
@@ -86,6 +102,7 @@ class LLMModel(ABC):
         use_randomness: bool = False,
         **kwargs,
     ) -> tuple[str, ChatMessages | None, ChatMessages, ModelUsage]:
+        raise NotImplementedError("Non-chat models are currently not supported")
         messages = [self._get_user_message(prompt)]
         if system_message is not None:
             prompt = system_message + prompt
@@ -95,13 +112,36 @@ class LLMModel(ABC):
             grammar=grammar,
             stop=stop,
             use_cache=use_cache,
-            max_tokens=self.kwargs.get("max_tokens", None),
+            max_tokens=self.max_tokens,
             use_randomness=use_randomness,
         )
 
         messages.append(self._get_assistant_message(reponse))
         return reponse, None, messages, usage
 
+    def build_demonstration_data(
+        self,
+        demonstrations: Iterable[tuple[str, str]],
+        instruction: str | None,
+        **kwargs,
+    ) -> ChatMessages:
+        if not isinstance(self, ChatModel):
+            raise ValueError(
+                f"Model {self} does not support building demonstration data"
+            )
+        messages = []
+        for input_, output in demonstrations:
+            messages.extend(
+                self.build_input_data(input_, instruction=instruction, **kwargs)[1]
+            )
+            messages.append(self._get_assistant_message(output))
+        return messages
+
+    def build_input_data(
+        self, prompt: str, instruction: str | None = None, **kwargs
+    ) -> tuple[str | None, ChatMessages]:
+        return instruction, [self._get_user_message(prompt)]
+
     def _get_prediction_prefix(self):
         # some models use a special token prefix for the prediction
         return None
@@ -183,7 +223,7 @@ class LLMModel(ABC):
         return (
             str(self.model_name).replace("/", "_")
             + "/"
-            + self.get_hash_from_kwargs(**self.kwargs)
+            + self.get_hash_from_kwargs(**self.cache_kwargs)
         )
 
     @property
@@ -255,7 +295,11 @@ class Llama(LLMModel):
 
         # pass all arguments to super constructor which should be taken into account for caching
         # needs to be called after model is initialized
-        super().__init__(ignore_cache_kwargs=ignore_cache_kwargs, **hashed_model_kwargs)
+        super().__init__(
+            cache_kwargs=hashed_model_kwargs,
+            ignore_cache_kwargs=ignore_cache_kwargs,
+            **kwargs,
+        )
 
     def _create_completion(
         self,
@@ -339,9 +383,10 @@ class ChatModel:
     @weave.op()
     def create_completion(
         self,
-        system_message: str | None,
-        prompt: str,
+        system_message: str | ChatMessage | None,
+        messages: ChatMessages | None = None,
         *,
+        prompt: str | None = None,
         use_cache: bool = False,
         grammar: llama_cpp.LlamaGrammar | None = None,
         stop: str | None = None,
@@ -349,25 +394,28 @@ class ChatModel:
         use_randomness: bool = False,
         **kwargs,
     ) -> tuple[str, ChatMessages | None, ChatMessages, ModelUsage]:
-        # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
-        # TODO is it better to check for a system message in the history?
-        messages = [self._get_user_message(prompt)]
+        if messages is None:
+            assert prompt is not None, "Either messages or prompt must be provided"
+            messages = [self._get_user_message(prompt)]
 
-        if history is None:
-            if system_message:
-                history = [self._get_system_message(system_message)]
-            else:
-                history = []
         # we prepend the history to the messages
         # the chat format should take care of adding appropriate assistant messages for generating the completion
-        messages_for_model = history + messages
+        messages_for_model = messages
+        if history is None:
+            history = []
+        messages_for_model = history + messages_for_model
+        # prepend system message if available
+        if system_message is not None:
+            if isinstance(system_message, str):
+                system_message = self._get_system_message(system_message)
+            messages_for_model = [system_message] + messages_for_model
 
         reponse, usage = self._create_completion(
             messages=messages_for_model,
             grammar=grammar,
             stop=stop,
             use_cache=use_cache,
-            max_tokens=self.kwargs.get("max_tokens", None),
+            max_tokens=self.max_tokens,
             use_randomness=use_randomness,
         )
 
@@ -389,7 +437,7 @@ class LlamaChat(ChatModel, Llama):
         use_randomness: bool,
     ):
         # input(
-        #     f"The input for the model will look like this:\n{format_llama3(messages).prompt}"
+        #     f"The input for a Llama3.x model will look like this:\n{format_llama3(messages).prompt}"
         # )
         # setup kwargs for model call
         model_call_kwargs = {
@@ -446,7 +494,9 @@ class HfChat(ChatModel, LLMModel):
         ignore_cache_kwargs.extend(["torch_dtype"])
 
         # pass all arguments to super constructor which should be taken into account for caching
-        super().__init__(**model_kwargs, ignore_cache_kwargs=ignore_cache_kwargs)
+        super().__init__(
+            cache_kwargs=model_kwargs, ignore_cache_kwargs=ignore_cache_kwargs, **kwargs
+        )
 
         # initialize model
         self.pipeline = transformers.pipeline(
@@ -487,7 +537,7 @@ class HfChat(ChatModel, LLMModel):
             model_call_kwargs["do_sample"] = False
 
         # input(
-        #     f"The input for the model will look like this:\n{self.pipeline.tokenizer.apply_chat_template(messages, tokenize=False)}"
+        #     f"The input for the model will look like this:\n'{self.pipeline.tokenizer.apply_chat_template(model_call_kwargs["text_inputs"], tokenize=False, add_generation_prompt=True)}'"
         # )
         response = self._call_model(
             self.pipeline,
@@ -506,41 +556,39 @@ class HfChat(ChatModel, LLMModel):
 
 # For Alpaca we build inputs to follow the fine-tuning chat format like https://github.com/tatsu-lab/stanford_alpaca/blob/761dc5bfbdeeffa89b8bff5d038781a4055f796a/README.md?plain=1#L56-L66
 class AlpacaHfChat(HfChat):
+    SYSTEM_MESSAGE = "You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."
+
     def __init__(
         self,
         model: str = "chavinlo/alpaca-native",
-        ignore_cache_kwargs: list[str] | None = None,
         **kwargs,
     ):
-        super().__init__(model=model, ignore_cache_kwargs=ignore_cache_kwargs, **kwargs)
+        super().__init__(model=model, **kwargs)
 
         # chat template for Alpaca adapted from https://huggingface.co/Vezora/Mistral-22B-v0.1/blob/c15d70465e2fc46c3c4d7fec8fb62f533d4ef09b/tokenizer_config.json#L30
-        self.pipeline.tokenizer.chat_template = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ system_message + '\\n\\n' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '### Instruction:\\n' + message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response:\\n'  + message['content'].strip()}}{% endif %}{% endfor %}"
-
-    def _create_completion(
-        self,
-        messages: list[dict[str, str]],
-        *,
-        use_cache: bool,
-        stop: str | None,
-        max_tokens: int | None,
-        use_randomness: bool,
-        **kwargs,
-    ):
-        # for some reason adding an empty assistant message yields different generations than adding it manually in the chat template
-        return super()._create_completion(
-            messages + [self._get_assistant_message("")],
-            use_cache=use_cache,
-            stop=stop,
-            max_tokens=max_tokens,
-            use_randomness=use_randomness,
-            **kwargs,
-        )
+        self.pipeline.tokenizer.chat_template = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ system_message }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '\\n\\n### Instruction:\\n' + message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response:\\n'  + message['content'].strip()}}{% endif %}{% endfor %}"
+
+    def build_input_data(
+        self, prompt: str, instruction: str | None = None, **kwargs
+    ) -> tuple[str, ChatMessages]:
+        # For Alpaca we prepend the instruction to each input (assuming the prompt provides the context; otherwise the prompt itself should contain the instruction)
+        return self.SYSTEM_MESSAGE, [
+            self._get_user_message(
+                (
+                    (instruction + "\n\n" + self._get_input_prefix())
+                    if instruction is not None
+                    else ""
+                )
+                + prompt
+            )
+        ]
 
-    def _get_input_prefix(self):
+    @staticmethod
+    def _get_input_prefix():
         return "### Input:\n"
 
-    def _get_prediction_prefix(self):
+    @staticmethod
+    def _get_prediction_prefix():
         # Alpaca uses a special token prefix for the prediction
         return "\n### Response:\n"
 
@@ -559,7 +607,11 @@ class OpenAIChat(ChatModel, LLMModel):
         self.openai_client = openai.OpenAI()
         self._model_name = openai_model
 
-        super().__init__(model=openai_model, ignore_cache_kwargs=ignore_cache_kwargs)
+        super().__init__(
+            cache_kwargs=dict(model=openai_model),
+            ignore_cache_kwargs=ignore_cache_kwargs,
+            **kwargs,
+        )
 
     def _create_completion(
         self,
diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index 6b2a334..ae260f5 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -10,15 +10,13 @@ from datasets import Dataset, load_dataset
 from llama_cpp import LlamaGrammar
 from tqdm import tqdm
 
-from evoprompt.models import LLMModel
+from evoprompt.models import ChatMessage, ChatMessages, LLMModel
 from evoprompt.opt_types import ModelUsage
 from evoprompt.utils import log_calls
 
 logger = logging.getLogger(__name__)
 
 
-SYSTEM_MESSAGE = "You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."
-
 DatasetDatum = dict
 
 
@@ -344,7 +342,6 @@ class Task(metaclass=ABCMeta):
     ) -> Iterable[int]:
         pass
 
-
     @log_calls("Evaluating validation dataset")
     def evaluate_validation(
         self, prompt: str, parent_histories: list[list[float]] | None = None
@@ -371,11 +368,15 @@ class Task(metaclass=ABCMeta):
         evaluation_history = []
 
         # augment prompt with demonstration samples
-        prompt_with_examples = self.build_demonstration_prompt(self.demonstration_samples, prompt=prompt)
+        demonstration_prompt = self.build_demonstration_prompt(
+            self.demonstration_samples, instruction=prompt
+        )
 
         for datum in dataset_iterator:
             # run prediction
-            response, usage = self.predict(prompt=prompt_with_examples, datum=datum)
+            response, usage = self.predict(
+                instruction=prompt, history=demonstration_prompt, datum=datum
+            )
             logger.debug(f"Response: '{response}'")
             # parse response
             response = self._parse_response(response=response)
@@ -399,17 +400,22 @@ class Task(metaclass=ABCMeta):
                 break
 
         return self._aggregate_result(results), evaluation_usage, evaluation_history
-    
+
     @weave.op()
-    def predict(self, prompt: str, datum: DatasetDatum) -> tuple[str, ModelUsage]:
+    def predict(
+        self, instruction: str, history: ChatMessages, datum: DatasetDatum
+    ) -> tuple[str, ModelUsage]:
         # run model for inference using grammar to constrain output
         # TODO grammar also depends on prompt and vice-versa -> what are good labels?
         # build prompt for current sample
-        prompt_for_datum = self.build_prompt_input(datum, prompt=prompt, use_prediction_prefix=self.model._get_prediction_prefix() is None)
+        system_message, prompt_for_datum = self.build_prompt_input(
+            datum, instruction, use_prediction_prefix=self.force_task_prediction_prefix
+        )
         logger.debug(f"Prompt for datum:\n{prompt_for_datum}")
         response, _, _, usage = self.model.create_completion(
-            system_message=SYSTEM_MESSAGE,
-            prompt=prompt_for_datum,
+            system_message=system_message,
+            messages=prompt_for_datum,
+            history=history,
             # grammar can be applied to constrain the model output
             grammar=self._get_grammar(datum) if self.use_grammar else None,
             # we use cached completions to speed up the process although we loose the non-deterministic behavior of LMs, but we're ok with a single result
@@ -423,28 +429,43 @@ class Task(metaclass=ABCMeta):
             response = response.strip()
 
         return response, usage
-    
+
     def build_prompt_input(
-        self, sample, prompt: str = "", use_prediction_prefix: bool = False,
-    ) -> str:
+        self,
+        sample: DatasetDatum,
+        instruction: str,
+        use_prediction_prefix: bool = False,
+    ) -> tuple[ChatMessage, ChatMessages]:
         # the default is to use the prompt as is and concatenate the datum string
-        prompt += f"\n\n{self.model._get_input_prefix() if self.model._get_input_prefix() is not None else ""}{self._get_prompt_text_for_datum(sample, use_prefix=self.force_task_input_prefix or not self.model._get_input_prefix())}"
+        prompt = self._get_prompt_text_for_datum(
+            sample, use_prefix=self.force_task_input_prefix
+        )
         if use_prediction_prefix:
-            prompt += f"\n{self._get_prediction_prefix().strip()} "
-        return prompt.strip()
-    
+            prompt += f"\n{self._get_prediction_prefix().strip()}"
+        return self.model.build_input_data(prompt, instruction)
+
     def build_demonstration_prompt(
         self,
-        demonstration_samples: list[dict],
-        prompt: str = "",
-    ) -> str:
-        for sample in demonstration_samples:
-            prompt += "\n\n" + self.build_prompt_input(sample)
-            prompt += f"\n{self.model._get_prediction_prefix() if self.model._get_prediction_prefix() is not None else self._get_prediction_prefix()}{self._get_gold_label_generation_for_datum(sample)}"
-        return prompt.strip()
-    
+        demonstration_samples: Iterable[DatasetDatum],
+        instruction: str | None = None,
+    ) -> ChatMessages:
+        return self.model.build_demonstration_data(
+            [
+                (
+                    self._get_prompt_text_for_datum(
+                        sample, use_prefix=self.force_task_input_prefix
+                    ),
+                    self._get_gold_label_generation_for_datum(sample),
+                )
+                for sample in demonstration_samples
+            ],
+            instruction=instruction,
+        )
+
     @abstractmethod
-    def _get_prompt_text_for_datum(self, datum: DatasetDatum, use_prefix: bool = False) -> str: ...
+    def _get_prompt_text_for_datum(
+        self, datum: DatasetDatum, use_prefix: bool = False
+    ) -> str: ...
 
     @abstractmethod
     def _get_prediction_prefix() -> str: ...
diff --git a/evoprompt/utils.py b/evoprompt/utils.py
index 701ae38..2d41444 100644
--- a/evoprompt/utils.py
+++ b/evoprompt/utils.py
@@ -57,7 +57,7 @@ def setup_console_logger(verbosity_level: int = 0):
     logging.basicConfig(handlers=(console_handler,), level=logging.NOTSET)
 
 
-run_name_prompt = (
+RUN_NAME_PROMPT = (
     "Generate a random name that sounds german or dutch. "
     "The parts should be separated by underscores and contain only lowercase. "
     "Only return the name without any text before or after."
@@ -80,7 +80,7 @@ def initialize_run_directory(model):
     # make sure that we use high randomness for generating the run name even if a seed is set for the model
     response, _, _, _ = model.create_completion(
         system_message=None,
-        prompt=run_name_prompt,
+        prompt=RUN_NAME_PROMPT,
         use_randomness=True,
     )
     run_name_match = re.search(r"^\w+$", response, re.MULTILINE)
diff --git a/main.py b/main.py
index 40c8d25..f74fbea 100644
--- a/main.py
+++ b/main.py
@@ -59,8 +59,9 @@ if __name__ == "__main__":
 
     if options.wandb_project is not None:
         # init wandb and weave tracing (with disabled call link printing)
-        weave_settings = UserSettings(disabled=False, print_call_link=False)
-        weave.init(project_name=options.wandb_project, settings=weave_settings)
+        # TODO weave recently had 500 errors quite often, so we disable it for now
+        # weave_settings = UserSettings(disabled=False, print_call_link=False)
+        # weave.init(project_name=options.wandb_project, settings=weave_settings)
         wandb.init(project=options.wandb_project, config=options.__dict__)
 
     # set up console logging and rnd
-- 
GitLab