From 537d12d9f8b15334be24fe55c82e2f3a44a4c519 Mon Sep 17 00:00:00 2001 From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de> Date: Fri, 4 Oct 2024 21:08:57 +0200 Subject: [PATCH] Use chat-style messages for demonstration samples considering the model format --- evoprompt/models.py | 152 +++++++++++++++++++++++++++-------------- evoprompt/task/task.py | 75 ++++++++++++-------- evoprompt/utils.py | 4 +- main.py | 5 +- 4 files changed, 155 insertions(+), 81 deletions(-) diff --git a/evoprompt/models.py b/evoprompt/models.py index 8653c0d..4012ce4 100644 --- a/evoprompt/models.py +++ b/evoprompt/models.py @@ -1,8 +1,11 @@ +from collections.abc import Iterable import hashlib import inspect import json import logging import random +import sqlite3 +import time import warnings from abc import ABC, abstractmethod from argparse import ArgumentParser @@ -23,7 +26,8 @@ logger = logging.getLogger(__name__) logging.captureWarnings(True) warnings.simplefilter("once") -ChatMessages = list[dict[str, str]] +ChatMessage = dict[str, str] +ChatMessages = list[ChatMessage] class LLMModel(ABC): @@ -57,17 +61,29 @@ class LLMModel(ABC): LLMModel.loaded_models[cls] = (model, key) return model - def __init__(self, ignore_cache_kwargs: list[str] | None = None, **kwargs): + def __init__( + self, cache_kwargs: dict, ignore_cache_kwargs: list[str] | None = None, **kwargs + ): self.usage = ModelUsage() # store kwargs for caching - self.kwargs = kwargs.copy() - self.kwargs["ignore_cache_kwargs"] = ignore_cache_kwargs + self.cache_kwargs = cache_kwargs.copy() + self.cache_kwargs["ignore_cache_kwargs"] = ignore_cache_kwargs + + self.max_tokens = kwargs.get("max_tokens", None) # set up caching for model calls self._call_model_cached = None - if not self.kwargs.get("disable_cache", False): - cache = Cache(Path(".cache_dir", self.model_cache_key)) + if kwargs.get("disable_cache", False): + logger.info("Caching is disabled") + else: + while True: + try: + cache = Cache(Path(".cache_dir", self.model_cache_key)) + break + except sqlite3.OperationalError: + logger.warning("Failed to open cache, retrying...", exc_info=True) + time.sleep(random.random() * 5) @cache.memoize(typed=True, ignore=[0, "func"]) def _call_function(func, *args, **kwargs): @@ -86,6 +102,7 @@ class LLMModel(ABC): use_randomness: bool = False, **kwargs, ) -> tuple[str, ChatMessages | None, ChatMessages, ModelUsage]: + raise NotImplementedError("Non-chat models are currently not supported") messages = [self._get_user_message(prompt)] if system_message is not None: prompt = system_message + prompt @@ -95,13 +112,36 @@ class LLMModel(ABC): grammar=grammar, stop=stop, use_cache=use_cache, - max_tokens=self.kwargs.get("max_tokens", None), + max_tokens=self.max_tokens, use_randomness=use_randomness, ) messages.append(self._get_assistant_message(reponse)) return reponse, None, messages, usage + def build_demonstration_data( + self, + demonstrations: Iterable[tuple[str, str]], + instruction: str | None, + **kwargs, + ) -> ChatMessages: + if not isinstance(self, ChatModel): + raise ValueError( + f"Model {self} does not support building demonstration data" + ) + messages = [] + for input_, output in demonstrations: + messages.extend( + self.build_input_data(input_, instruction=instruction, **kwargs)[1] + ) + messages.append(self._get_assistant_message(output)) + return messages + + def build_input_data( + self, prompt: str, instruction: str | None = None, **kwargs + ) -> ChatMessages: + return instruction, [self._get_user_message(prompt)] + def 
_get_prediction_prefix(self): # some models use a special token prefix for the prediction return None @@ -183,7 +223,7 @@ class LLMModel(ABC): return ( str(self.model_name).replace("/", "_") + "/" - + self.get_hash_from_kwargs(**self.kwargs) + + self.get_hash_from_kwargs(**self.cache_kwargs) ) @property @@ -255,7 +295,11 @@ class Llama(LLMModel): # pass all arguments to super constructor which should be taken into account for caching # needs to be called after model is initialized - super().__init__(ignore_cache_kwargs=ignore_cache_kwargs, **hashed_model_kwargs) + super().__init__( + cache_kwargs=hashed_model_kwargs, + ignore_cache_kwargs=ignore_cache_kwargs, + **kwargs, + ) def _create_completion( self, @@ -339,9 +383,10 @@ class ChatModel: @weave.op() def create_completion( self, - system_message: str | None, - prompt: str, + system_message: str | ChatMessage | None, + messages: ChatMessages | None = None, *, + prompt: str | None = None, use_cache: bool = False, grammar: llama_cpp.LlamaGrammar | None = None, stop: str | None = None, @@ -349,25 +394,28 @@ class ChatModel: use_randomness: bool = False, **kwargs, ) -> tuple[str, ChatMessages | None, ChatMessages, ModelUsage]: - # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case - # TODO is it better to check for a system message in the history? - messages = [self._get_user_message(prompt)] + if messages is None: + assert prompt is not None, "Either messages or prompt must be provided" + messages = [self._get_user_message(prompt)] - if history is None: - if system_message: - history = [self._get_system_message(system_message)] - else: - history = [] # we prepend the history to the messages # the chat format should take care of adding appropriate assistant messages for generating the completion - messages_for_model = history + messages + messages_for_model = messages + if history is None: + history = [] + messages_for_model = history + messages_for_model + # prepend system message if available + if system_message is not None: + if isinstance(system_message, str): + system_message = self._get_system_message(system_message) + messages_for_model = [system_message] + messages_for_model reponse, usage = self._create_completion( messages=messages_for_model, grammar=grammar, stop=stop, use_cache=use_cache, - max_tokens=self.kwargs.get("max_tokens", None), + max_tokens=self.max_tokens, use_randomness=use_randomness, ) @@ -389,7 +437,7 @@ class LlamaChat(ChatModel, Llama): use_randomness: bool, ): # input( - # f"The input for the model will look like this:\n{format_llama3(messages).prompt}" + # f"The input for a Llama3.x model will look like this:\n{format_llama3(messages).prompt}" # ) # setup kwargs for model call model_call_kwargs = { @@ -446,7 +494,9 @@ class HfChat(ChatModel, LLMModel): ignore_cache_kwargs.extend(["torch_dtype"]) # pass all arguments to super constructor which should be taken into account for caching - super().__init__(**model_kwargs, ignore_cache_kwargs=ignore_cache_kwargs) + super().__init__( + cache_kwargs=model_kwargs, ignore_cache_kwargs=ignore_cache_kwargs, **kwargs + ) # initialize model self.pipeline = transformers.pipeline( @@ -487,7 +537,7 @@ class HfChat(ChatModel, LLMModel): model_call_kwargs["do_sample"] = False # input( - # f"The input for the model will look like this:\n{self.pipeline.tokenizer.apply_chat_template(messages, tokenize=False)}" + # f"The input for the model will look like 
this:\n'{self.pipeline.tokenizer.apply_chat_template(model_call_kwargs["text_inputs"], tokenize=False, add_generation_prompt=True)}'" # ) response = self._call_model( self.pipeline, @@ -506,41 +556,39 @@ class HfChat(ChatModel, LLMModel): # For Alpaca we build inputs to follow the fine-tuning chat format like https://github.com/tatsu-lab/stanford_alpaca/blob/761dc5bfbdeeffa89b8bff5d038781a4055f796a/README.md?plain=1#L56-L66 class AlpacaHfChat(HfChat): + SYSTEM_MESSAGE = "You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request." + def __init__( self, model: str = "chavinlo/alpaca-native", - ignore_cache_kwargs: list[str] | None = None, **kwargs, ): - super().__init__(model=model, ignore_cache_kwargs=ignore_cache_kwargs, **kwargs) + super().__init__(model=model, **kwargs) # chat template for Alpaca adapted from https://huggingface.co/Vezora/Mistral-22B-v0.1/blob/c15d70465e2fc46c3c4d7fec8fb62f533d4ef09b/tokenizer_config.json#L30 - self.pipeline.tokenizer.chat_template = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ system_message + '\\n\\n' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '### Instruction:\\n' + message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response:\\n' + message['content'].strip()}}{% endif %}{% endfor %}" - - def _create_completion( - self, - messages: list[dict[str, str]], - *, - use_cache: bool, - stop: str | None, - max_tokens: int | None, - use_randomness: bool, - **kwargs, - ): - # for some reason adding an empty assistant message yields different generations than adding it manually in the chat template - return super()._create_completion( - messages + [self._get_assistant_message("")], - use_cache=use_cache, - stop=stop, - max_tokens=max_tokens, - use_randomness=use_randomness, - **kwargs, - ) + self.pipeline.tokenizer.chat_template = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ system_message }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '\\n\\n### Instruction:\\n' + message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response:\\n' + message['content'].strip()}}{% endif %}{% endfor %}" + + def build_input_data( + self, prompt: str, instruction: str = None, **kwargs + ) -> ChatMessages: + # For Alpaca we add the instruction for each input (and assume that content is the context, otherwise the content should contain the instruction) + return self.SYSTEM_MESSAGE, [ + self._get_user_message( + ( + (instruction + "\n\n" + self._get_input_prefix()) + if instruction is not None + else "" + ) + + prompt + ) + ] - def _get_input_prefix(self): + @staticmethod + def _get_input_prefix(): return "### 
Input:\n" - def _get_prediction_prefix(self): + @staticmethod + def _get_prediction_prefix(): # Alpaca uses a special token prefix for the prediction return "\n### Response:\n" @@ -559,7 +607,11 @@ class OpenAIChat(ChatModel, LLMModel): self.openai_client = openai.OpenAI() self._model_name = openai_model - super().__init__(model=openai_model, ignore_cache_kwargs=ignore_cache_kwargs) + super().__init__( + cache_kwargs=dict(model=openai_model), + ignore_cache_kwargs=ignore_cache_kwargs, + **kwargs, + ) def _create_completion( self, diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py index 6b2a334..ae260f5 100644 --- a/evoprompt/task/task.py +++ b/evoprompt/task/task.py @@ -10,15 +10,13 @@ from datasets import Dataset, load_dataset from llama_cpp import LlamaGrammar from tqdm import tqdm -from evoprompt.models import LLMModel +from evoprompt.models import ChatMessage, ChatMessages, LLMModel from evoprompt.opt_types import ModelUsage from evoprompt.utils import log_calls logger = logging.getLogger(__name__) -SYSTEM_MESSAGE = "You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request." - DatasetDatum = dict @@ -344,7 +342,6 @@ class Task(metaclass=ABCMeta): ) -> Iterable[int]: pass - @log_calls("Evaluating validation dataset") def evaluate_validation( self, prompt: str, parent_histories: list[list[float]] | None = None @@ -371,11 +368,15 @@ class Task(metaclass=ABCMeta): evaluation_history = [] # augment prompt with demonstration samples - prompt_with_examples = self.build_demonstration_prompt(self.demonstration_samples, prompt=prompt) + demonstration_prompt = self.build_demonstration_prompt( + self.demonstration_samples, instruction=prompt + ) for datum in dataset_iterator: # run prediction - response, usage = self.predict(prompt=prompt_with_examples, datum=datum) + response, usage = self.predict( + instruction=prompt, history=demonstration_prompt, datum=datum + ) logger.debug(f"Response: '{response}'") # parse response response = self._parse_response(response=response) @@ -399,17 +400,22 @@ class Task(metaclass=ABCMeta): break return self._aggregate_result(results), evaluation_usage, evaluation_history - + @weave.op() - def predict(self, prompt: str, datum: DatasetDatum) -> tuple[str, ModelUsage]: + def predict( + self, instruction: str, history: ChatMessages, datum: DatasetDatum + ) -> tuple[str, ModelUsage]: # run model for inference using grammar to constrain output # TODO grammar also depends on prompt and vice-versa -> what are good labels? 
# build prompt for current sample - prompt_for_datum = self.build_prompt_input(datum, prompt=prompt, use_prediction_prefix=self.model._get_prediction_prefix() is None) + system_message, prompt_for_datum = self.build_prompt_input( + datum, instruction, use_prediction_prefix=self.force_task_prediction_prefix + ) logger.debug(f"Prompt for datum:\n{prompt_for_datum}") response, _, _, usage = self.model.create_completion( - system_message=SYSTEM_MESSAGE, - prompt=prompt_for_datum, + system_message=system_message, + messages=prompt_for_datum, + history=history, # grammar can be applied to constrain the model output grammar=self._get_grammar(datum) if self.use_grammar else None, # we use cached completions to speed up the process although we loose the non-deterministic behavior of LMs, but we're ok with a single result @@ -423,28 +429,43 @@ class Task(metaclass=ABCMeta): response = response.strip() return response, usage - + def build_prompt_input( - self, sample, prompt: str = "", use_prediction_prefix: bool = False, - ) -> str: + self, + sample: DatasetDatum, + instruction: str, + use_prediction_prefix: bool = False, + ) -> tuple[ChatMessage, ChatMessages]: # the default is to use the prompt as is and concatenate the datum string - prompt += f"\n\n{self.model._get_input_prefix() if self.model._get_input_prefix() is not None else ""}{self._get_prompt_text_for_datum(sample, use_prefix=self.force_task_input_prefix or not self.model._get_input_prefix())}" + prompt = self._get_prompt_text_for_datum( + sample, use_prefix=self.force_task_input_prefix + ) if use_prediction_prefix: - prompt += f"\n{self._get_prediction_prefix().strip()} " - return prompt.strip() - + prompt += f"\n{self._get_prediction_prefix().strip()}" + return self.model.build_input_data(prompt, instruction) + def build_demonstration_prompt( self, - demonstration_samples: list[dict], - prompt: str = "", - ) -> str: - for sample in demonstration_samples: - prompt += "\n\n" + self.build_prompt_input(sample) - prompt += f"\n{self.model._get_prediction_prefix() if self.model._get_prediction_prefix() is not None else self._get_prediction_prefix()}{self._get_gold_label_generation_for_datum(sample)}" - return prompt.strip() - + demonstration_samples: Iterable[DatasetDatum], + instruction: str = None, + ) -> ChatMessages: + return self.model.build_demonstration_data( + [ + ( + self._get_prompt_text_for_datum( + sample, use_prefix=self.force_task_input_prefix + ), + self._get_gold_label_generation_for_datum(sample), + ) + for sample in demonstration_samples + ], + instruction=instruction, + ) + @abstractmethod - def _get_prompt_text_for_datum(self, datum: DatasetDatum, use_prefix: bool = False) -> str: ... + def _get_prompt_text_for_datum( + self, datum: DatasetDatum, use_prefix: bool = False + ) -> str: ... @abstractmethod def _get_prediction_prefix() -> str: ... diff --git a/evoprompt/utils.py b/evoprompt/utils.py index 701ae38..2d41444 100644 --- a/evoprompt/utils.py +++ b/evoprompt/utils.py @@ -57,7 +57,7 @@ def setup_console_logger(verbosity_level: int = 0): logging.basicConfig(handlers=(console_handler,), level=logging.NOTSET) -run_name_prompt = ( +RUN_NAME_PROMPT = ( "Generate a random name that sounds german or dutch. " "The parts should be separated by underscores and contain only lowercase. " "Only return the name without any text before or after." 
@@ -80,7 +80,7 @@ def initialize_run_directory(model): # make sure that we use high randomness for generating the run name even if a seed is set for the model response, _, _, _ = model.create_completion( system_message=None, - prompt=run_name_prompt, + prompt=RUN_NAME_PROMPT, use_randomness=True, ) run_name_match = re.search(r"^\w+$", response, re.MULTILINE) diff --git a/main.py b/main.py index 40c8d25..f74fbea 100644 --- a/main.py +++ b/main.py @@ -59,8 +59,9 @@ if __name__ == "__main__": if options.wandb_project is not None: # init wandb and weave tracing (with disabled call link printing) - weave_settings = UserSettings(disabled=False, print_call_link=False) - weave.init(project_name=options.wandb_project, settings=weave_settings) + # TODO weave recently had 500 errors quite often, so we disable it for now + # weave_settings = UserSettings(disabled=False, print_call_link=False) + # weave.init(project_name=options.wandb_project, settings=weave_settings) wandb.init(project=options.wandb_project, config=options.__dict__) # set up console logging and rnd -- GitLab
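The sketch below is a simplified, self-contained illustration of the approach this patch takes: few-shot demonstration samples are turned into alternating user/assistant chat messages that are prepended as history to the actual query, instead of being concatenated into one flat prompt string. The helper names mirror `build_demonstration_data` and `build_input_data` from the patch, but they are stand-ins rather than the repository's implementation, and the sentiment-classification instruction and samples are made up for the example.

```python
# Minimal sketch (assumes Python 3.10+), not the repository's code: demonstrations
# become chat messages so that each model's chat template can render them natively.
from collections.abc import Iterable

ChatMessage = dict[str, str]
ChatMessages = list[ChatMessage]


def build_input_data(
    prompt: str, instruction: str | None = None
) -> tuple[str | None, ChatMessages]:
    # Generic chat models: the instruction acts as the system message and the
    # sample input becomes a single user message. (Alpaca-style models would
    # instead prepend the instruction and an "### Input:" prefix to the user turn.)
    return instruction, [{"role": "user", "content": prompt}]


def build_demonstration_data(
    demonstrations: Iterable[tuple[str, str]], instruction: str | None = None
) -> ChatMessages:
    # Each (input, gold output) pair becomes a user message followed by an
    # assistant message, yielding the alternating history a chat template expects.
    messages: ChatMessages = []
    for input_, output in demonstrations:
        _, user_messages = build_input_data(input_, instruction)
        messages.extend(user_messages)
        messages.append({"role": "assistant", "content": output})
    return messages


if __name__ == "__main__":
    instruction = "Classify the sentiment of the given text."
    demos = [("I loved this movie!", "positive"), ("Terrible acting.", "negative")]
    history = build_demonstration_data(demos, instruction)
    system_message, query = build_input_data("The plot was engaging.", instruction)
    # A chat model would then be called with roughly
    # [system message] + history + query, analogous to
    # create_completion(system_message=..., messages=query, history=history).
    print([{"role": "system", "content": system_message}] + history + query)
```

The design point this sketch tries to capture is that the flat-string few-shot prompt is replaced by structured messages, so formatting decisions (system message placement, "### Instruction:"/"### Response:" prefixes, generation prompts) are delegated to the model's own chat template rather than being hard-coded in the task.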