diff --git a/evoprompt/evolution/evolution.py b/evoprompt/evolution/evolution.py index 57556f0c0b5716fc12390d01076a84377df03261..e57e6817dd3b93b934d78e5ea6ece439a1c15475 100644 --- a/evoprompt/evolution/evolution.py +++ b/evoprompt/evolution/evolution.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable import logging import re from abc import ABCMeta, abstractmethod @@ -7,7 +8,6 @@ from tqdm import trange import wandb import weave -from evoprompt.evolution.template import get_demonstration_prompt_template from evoprompt.evolution.template_de import ( DE_DEMONSTRATION_DATA_CLS, DE_DEMONSTRATION_DATA_SIM, @@ -23,7 +23,7 @@ from evoprompt.evolution.template_ga import ( GA_DEMONSTRATION_DATA_SIM, GA_PROMPT, ) -from evoprompt.models import LLMModel +from evoprompt.models import ChatMessages, LLMModel from evoprompt.opt_types import ModelUsage, Prompt from evoprompt.optimization import Judgement, PromptOptimization from evoprompt.task import Task @@ -50,14 +50,12 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta): *, task: Task, evolution_model: LLMModel, - evaluation_model: LLMModel, judge_model: LLMModel | None, run_options: dict[str, Any] = {}, ) -> None: super().__init__( task=task, evolution_model=evolution_model, - evaluation_model=evaluation_model, judge_model=judge_model, run_options=run_options, ) @@ -93,6 +91,43 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta): ) -> tuple[str, list[Judgement], ModelUsage]: pass + def build_demonstration_prompt( + self, + demonstration_samples: Iterable[tuple[str, str]], + instruction: str = None, + ) -> ChatMessages: + return self.evolution_model.build_demonstration_data( + demonstration_samples, + instruction=instruction, + ) + + def parse_response( + self, + response: str, + start_token: str = "<prompt>", + end_token: str = "</prompt>", + allow_missing_end_token: bool = True, + ): + # TODO another option would be to select the first match that is not equal to " and " (which is part of the instruction and usually repeated in the response) + matches = re.findall( + # regex that matches any characters between last pair of `start_token` and `end_token`, and optionally allow missing `end_token` + rf"{start_token}(?!.*{start_token})(?:(.*){end_token}{"|(.*)" if allow_missing_end_token else ""})", + response, + flags=(re.IGNORECASE | re.DOTALL), + ) + if matches and any(matches[0]): + # there is always only a single match, and one group should be non-empty + if matches[0][0]: + evolved_prompt = matches[0][0] + else: + assert matches[0][1] + evolved_prompt = matches[0][1] + + else: + # could not extract prompt -> no new prompt + evolved_prompt = None + return evolved_prompt + @abstractmethod def update(self, *args, **kwargs): pass @@ -163,8 +198,8 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta): if p_i is not None: prompt_source = ( "corrected" # could also mean that user skipped the prompt - if not all(j.happy for j in judgements) - else "generated" + if False in [j.happy for j in judgements] + else "evolution" ) evolved_prompt = self.add_prompt( p_i, @@ -195,7 +230,13 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta): # Line 8: Return the best prompt, p∗, among the final population PT : # p∗ ↠argmaxp∈PT f(p, D) p = max(self.P[-1], key=lambda prompt: self.all_prompts[prompt.id].score) - logger.info("Best prompt with score %.2f: %s", p.score, p) + logger.info( + "Best prompt with score %.2f: %s (Source: %s - Gen: %d)", + p.score, + p, + p.meta["source"], + p.meta["gen"], + ) # We pick 
the prompt with the highest score on the development set and report its score on the testset. test_performance, _, _ = self.task.evaluate_test(p.content) @@ -241,20 +282,23 @@ class GeneticAlgorithm(EvolutionAlgorithm): # Based on this two-step process, we design instructions, guiding LLMs to # generate a new prompt based on these steps to perform Evo(·) in Algorithm 1. - filled_prompt = self.get_prompt_template().format( + prompt = self._get_prompt_template() + if isinstance(prompt, tuple) and len(prompt) > 1: + # extract demonstrations + prompt, demo_messages = prompt + filled_prompt = prompt.format( prompt1=prompt_1, prompt2=prompt_2, ) - evolved_prompt, history, recent_turn, usage = ( - self.evolution_model.create_completion( - system_message=SYSTEM_MESSAGE, - prompt=filled_prompt, - use_randomness=True, - ) + response, _, recent_turn, usage = self.evolution_model.create_completion( + system_message=SYSTEM_MESSAGE, + prompt=filled_prompt, + history=demo_messages if self.use_evolution_demo else None, + use_randomness=True, ) judgement = self.judge_and_correct_step( - filled_prompt, evolved_prompt, history, recent_turn + filled_prompt, response, history=None, recent_turn=recent_turn ) if judgement.skip: @@ -265,17 +309,18 @@ class GeneticAlgorithm(EvolutionAlgorithm): usage, ) - evolved_prompt = judgement.corrected_response - - if "<prompt>" in evolved_prompt: - evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0] + response = judgement.corrected_response + evolved_prompt = self.parse_response(response) - logger.info( - "GA-evolved prompts '%s' and '%s' into '%s'", - prompt_1, - prompt_2, - evolved_prompt, - ) + if evolved_prompt is None: + logger.info(f"Could not extract prompt from response: {response}") + else: + logger.info( + "GA-evolved prompts '%s' and '%s' into '%s'", + prompt_1, + prompt_2, + evolved_prompt, + ) return evolved_prompt, [judgement], usage @@ -301,22 +346,28 @@ class GeneticAlgorithm(EvolutionAlgorithm): return retained_prompts - def get_prompt_template(self): + def _get_prompt_template(self): if self.use_evolution_demo: if isinstance( self.task, (TextClassification, Summarization, QuestionAnswering) ): - return get_demonstration_prompt_template( - GA_PROMPT, GA_DEMONSTRATION_DATA_SIM - ) + demonstration_data = GA_DEMONSTRATION_DATA_SIM elif isinstance(self.task, Simplification): - return get_demonstration_prompt_template( - GA_PROMPT, GA_DEMONSTRATION_DATA_CLS - ) + demonstration_data = GA_DEMONSTRATION_DATA_CLS else: raise NotImplementedError( f"Prompt with demonstration data is not implemented for task of type {type(self.task)}."
) + + prompt_with_demonstration_data = self.build_demonstration_prompt( + [ + ( + GA_PROMPT.format(**demonstration_data), + demonstration_data["response"], + ) + ] + ) + return GA_PROMPT, prompt_with_demonstration_data return GA_PROMPT @@ -342,59 +393,56 @@ class DifferentialEvolution(EvolutionAlgorithm): prompts_current_evolution, key=lambda prompt: prompt.score ) - filled_prompt = self.get_prompt_template().format( + prompt = self._get_prompt_template() + if isinstance(prompt, tuple) and len(prompt) > 1: + # extract demonstrations + prompt, demo_messages = prompt + filled_prompt = prompt.format( prompt1=prompt_1, prompt2=prompt_2, prompt3=best_prompt_current_evolution, basic_prompt=prompts_current_evolution[current_iteration], ) - evolved_prompt, history, recent_turn, usage = ( - self.evolution_model.create_completion( - system_message=SYSTEM_MESSAGE, - prompt=filled_prompt, - use_randomness=True, - ) + response, _, recent_turn, usage = self.evolution_model.create_completion( + system_message=SYSTEM_MESSAGE, + prompt=filled_prompt, + history=demo_messages if self.use_evolution_demo else None, + use_randomness=True, ) judgement = self.judge_and_correct_step( - filled_prompt, evolved_prompt, history, recent_turn + filled_prompt, response, history=None, recent_turn=recent_turn ) if judgement.skip: - # skip this prompt, for DE this means using the basic prompt + # user asked to skip this prompt, for DE this means using the basic prompt return ( prompts_current_evolution[current_iteration].content, [judgement], usage, ) - evolved_prompt = judgement.corrected_response + response = judgement.corrected_response + evolved_prompt = self.parse_response(response) - matches = re.findall( - # regex that matches any characters between last pair of <prompt></prompt>, also if </prompt> is missing - r"<prompt>(?!.*<prompt>)(?:(.*)</prompt>|(.*))", - evolved_prompt, - flags=(re.IGNORECASE | re.DOTALL), - ) - if matches and any(matches[0]): - # there is always only a single match, and one group should be non-empty - if matches[0][0]: - evolved_prompt = matches[0][0] - else: - assert matches[0][1] - evolved_prompt = matches[0][1] + if evolved_prompt is None: + logger.info(f"Could not extract prompt from response: {evolved_prompt}") + + # no prompt was returned (e.g., evolved prompt could not be extracted), therefore, for DE, we use the basic prompt + return ( + prompts_current_evolution[current_iteration].content, + [judgement], + usage, + ) else: - # TODO what to do in this case? Discard generated prompt directly? 
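Reviewer note (illustrative sketch, not part of the patch): the extraction logic that the new `parse_response` helper centralizes, and that previously lived inline in the DE `evolve` step above, boils down to the following standalone function. The function name and example strings are made up for illustration.

```python
import re

def parse_tagged_prompt(response: str, allow_missing_end_token: bool = True) -> str | None:
    """Extract the text between the last <prompt>...</prompt> pair in a model response.

    Models sometimes drop the closing tag, so we optionally fall back to everything
    after the last opening tag (mirrors EvolutionAlgorithm.parse_response above).
    """
    tail = "|(.*)" if allow_missing_end_token else ""
    matches = re.findall(
        rf"<prompt>(?!.*<prompt>)(?:(.*)</prompt>{tail})",
        response,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if not matches or not any(matches[0]):
        return None  # could not extract a prompt
    # there is only a single match; one of the two capture groups is non-empty
    return matches[0][0] or matches[0][1]

assert parse_tagged_prompt("Sure! <prompt>Summarize the text.</prompt>") == "Summarize the text."
assert parse_tagged_prompt("Sure! <prompt>Summarize the text.") == "Summarize the text."
assert parse_tagged_prompt("No tags here.") is None
```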
- pass - - logger.info( - "DE-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'", - prompt_1, - prompt_2, - best_prompt_current_evolution, - prompts_current_evolution[current_iteration], - evolved_prompt, - ) + logger.info( + "DE-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'", + prompt_1, + prompt_2, + best_prompt_current_evolution, + prompts_current_evolution[current_iteration], + evolved_prompt, + ) return evolved_prompt, [judgement], usage @@ -412,22 +460,28 @@ class DifferentialEvolution(EvolutionAlgorithm): ] return population - def get_prompt_template(self): + def _get_prompt_template(self): if self.use_evolution_demo: if isinstance( self.task, (TextClassification, Summarization, QuestionAnswering) ): - return get_demonstration_prompt_template( - DE_PROMPT, DE_DEMONSTRATION_DATA_SIM - ) + demonstration_data = DE_DEMONSTRATION_DATA_SIM elif isinstance(self.task, Simplification): - return get_demonstration_prompt_template( - DE_PROMPT, DE_DEMONSTRATION_DATA_CLS - ) + demonstration_data = DE_DEMONSTRATION_DATA_CLS else: raise NotImplementedError( f"Prompt with demonstration data is not implemented for task of type {type(self.task)}." ) + + prompt_with_demonstration_data = self.build_demonstration_prompt( + [ + ( + DE_PROMPT.format(**demonstration_data), + demonstration_data["response"], + ) + ] + ) + return DE_PROMPT, prompt_with_demonstration_data return DE_PROMPT @@ -453,40 +507,64 @@ class DifferentialEvolutionWithCot(DifferentialEvolution): prompts_current_evolution, key=lambda prompt: prompt.score ) - history = None - response: str = "" + # list of evolution steps + evolutions_steps = [] + # list (turns) of list (demonstrations) + demos = [[]] judgements: list[Judgement] = [] usage: ModelUsage = ModelUsage() - for idx, prompt in enumerate(self.get_prompt_template()): + for idx, prompt in enumerate(self._get_prompt_template()): + if isinstance(prompt, tuple) and len(prompt) > 1: + # extract demonstrations + prompt, demo_messages = prompt + demos[-1].extend(demo_messages) + if self.use_evolution_demo: + messages_demos = self.condense_messages(demos[-1]) + else: + messages_demos = [] filled_prompt = prompt.format( prompt1=prompt_1, prompt2=prompt_2, prompt3=best_prompt_current_evolution, basic_prompt=prompts_current_evolution[current_iteration], ) + evolutions_steps.append( + self.evolution_model._get_user_message(filled_prompt) + ) + # TODO Shall we still use only a single turn containing all messages if we do not use demonstrations for evolution? + prompt = self.condense_messages(evolutions_steps, return_str=True) response, history, recent_turn, usage = ( self.evolution_model.create_completion( system_message=SYSTEM_MESSAGE, - prompt=filled_prompt, - history=history, + prompt=prompt, + history=messages_demos, # the models often repeat the instuction which could also contain </prompt> therefore we should not stop early - stop=None, # "</prompt>" if idx == len(DE_COT_PROMPTS) - 1 else None, + stop=None, use_randomness=True, ) ) + evolutions_steps.append( + self.evolution_model._get_assistant_message(response) + ) logger.debug( "Performed evolution (step %d) using DE-CoT:\n\tInputs: %s\n\tResponse: %s", idx, history + recent_turn, response, ) + # TODO use serialized messages as prompt or use previous evolution steps as history? 
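Reviewer note (illustrative sketch, not part of the patch): the CoT variant now folds all previous evolution steps into a single user turn before each completion. The snippet below mirrors the string-returning path of the `condense_messages` helper added further down in this diff; the step wording and the message factories are stand-ins for the real `DE_COT_PROMPTS` and `LLMModel._get_user_message` / `_get_assistant_message`.

```python
def user(content: str) -> dict:
    return {"role": "user", "content": content}

def assistant(content: str) -> dict:
    return {"role": "assistant", "content": content}

def condense_to_single_user_turn(messages: list[dict]) -> str:
    # all turns are concatenated into one user prompt; a trailing assistant turn
    # cannot be folded this way (the real helper asserts this for the string path)
    assert not messages or messages[-1]["role"] != "assistant"
    return "\n\n".join(message["content"] for message in messages)

steps = [
    user("Step 1: Identify the different parts between Prompt 1 and Prompt 2."),
    assistant("The two prompts differ in ..."),
    user("Step 2: Randomly mutate the different parts."),
]
single_turn_prompt = condense_to_single_user_turn(steps)
```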
+ input(f"{len(evolutions_steps)}, \n{evolutions_steps}") judgement = self.judge_and_correct_step( - filled_prompt, response, history=history, recent_turn=recent_turn + filled_prompt, + response, + history=evolutions_steps[:-2], + recent_turn=recent_turn, + # prompt, response, history=None, recent_turn=recent_turn ) judgements.append(judgement) if judgement.skip: - # skip this prompt, for DE this means using the basic prompt + # user asked to skip this prompt, for DE this means using the basic prompt return ( prompts_current_evolution[current_iteration].content, judgements, @@ -494,28 +572,62 @@ class DifferentialEvolutionWithCot(DifferentialEvolution): ) # replace last message with corrected response - recent_turn[-1]["content"] = judgement.corrected_response response = judgement.corrected_response # update history with recent turn history += recent_turn + history.append(self.evolution_model._get_assistant_message(response)) - # at this point we should get a new prompt - if "<prompt>" in response: - response = response.split("<prompt>")[1].split("</prompt>")[0] + evolved_prompt = self.parse_response(response) - logger.info( - "DE-CoT-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'", - prompt_1, - prompt_2, - best_prompt_current_evolution, - prompts_current_evolution[current_iteration], - response, - ) + if evolved_prompt is None: + logger.info(f"Could not extract prompt from response: {evolved_prompt}") + + # no prompt was returned (e.g., evolved prompt could not be extracted), therefore, for DE, we use the basic prompt + return ( + prompts_current_evolution[current_iteration].content, + judgements, + usage, + ) + else: + logger.info( + "DE-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'", + prompt_1, + prompt_2, + best_prompt_current_evolution, + prompts_current_evolution[current_iteration], + evolved_prompt, + ) + + return evolved_prompt, judgements, usage + + def condense_messages( + self, messages: list[ChatMessages], return_str: bool = False + ) -> list[dict] | str: + if not messages: + if return_str: + return "" + return [] - return response, judgements, usage + if messages[-1]["role"] == "assistant": + assistant_turn = messages[-1] + messages = messages[:-1] + else: + assistant_turn = None + + user_turn = "\n\n".join(message["content"] for message in messages) + if return_str: + assert ( + assistant_turn is None + ), "Cannot return string if most recent turn is from assistant." + return user_turn + + messages = [self.evolution_model._get_user_message(user_turn)] + if assistant_turn is not None: + messages.append(assistant_turn) + return messages - def get_prompt_template(self): + def _get_prompt_template(self): if self.use_evolution_demo: if isinstance( self.task, (TextClassification, Summarization, QuestionAnswering) @@ -528,9 +640,18 @@ class DifferentialEvolutionWithCot(DifferentialEvolution): f"Prompt with demonstration data is not implemented for task of type {type(self.task)}." ) - for prompt, demonstration_data_item in zip( + for prompt_template, demonstration_data_item in zip( DE_COT_PROMPTS, demonstration_data ): - yield get_demonstration_prompt_template(prompt, demonstration_data_item) + prompt_with_demonstration_data = self.build_demonstration_prompt( + [ + ( + prompt_template.format(**demonstration_data_item), + demonstration_data_item["response"], + ) + ] + ) + yield prompt_template, prompt_with_demonstration_data + # TODO how can we add a generation prefix for the model? 
else: yield from DE_COT_PROMPTS diff --git a/evoprompt/evolution/template.py b/evoprompt/evolution/template.py deleted file mode 100644 index a84dbb3f5526a5e63d0f7a19ab8e8053c86081fd..0000000000000000000000000000000000000000 --- a/evoprompt/evolution/template.py +++ /dev/null @@ -1,6 +0,0 @@ -def get_demonstration_prompt_template(prompt_template: str, demonstration_data: dict): - prompt_template_with_demo = prompt_template.format(**demonstration_data) - prompt_template_with_demo += "\n\n" + demonstration_data["response"] - prompt_template_with_demo += "\n\n" + prompt_template - prompt_template_with_demo += "\n\n" + demonstration_data["generation_prefix"] - return prompt_template_with_demo diff --git a/evoprompt/models.py b/evoprompt/models.py index 8653c0d5cbda4b9cebdbb07b556d50efdcbc08af..cb7c4ff342cd05cb723dacd8617fccaff830d246 100644 --- a/evoprompt/models.py +++ b/evoprompt/models.py @@ -1,8 +1,12 @@ +from collections.abc import Iterable import hashlib import inspect +from itertools import zip_longest import json import logging import random +import sqlite3 +import time import warnings from abc import ABC, abstractmethod from argparse import ArgumentParser @@ -23,7 +27,8 @@ logger = logging.getLogger(__name__) logging.captureWarnings(True) warnings.simplefilter("once") -ChatMessages = list[dict[str, str]] +ChatMessage = dict[str, str] +ChatMessages = list[ChatMessage] class LLMModel(ABC): @@ -57,17 +62,29 @@ class LLMModel(ABC): LLMModel.loaded_models[cls] = (model, key) return model - def __init__(self, ignore_cache_kwargs: list[str] | None = None, **kwargs): + def __init__( + self, cache_kwargs: dict, ignore_cache_kwargs: list[str] | None = None, **kwargs + ): self.usage = ModelUsage() # store kwargs for caching - self.kwargs = kwargs.copy() - self.kwargs["ignore_cache_kwargs"] = ignore_cache_kwargs + self.cache_kwargs = cache_kwargs.copy() + self.cache_kwargs["ignore_cache_kwargs"] = ignore_cache_kwargs + + self.max_tokens = kwargs.get("max_tokens", None) # set up caching for model calls self._call_model_cached = None - if not self.kwargs.get("disable_cache", False): - cache = Cache(Path(".cache_dir", self.model_cache_key)) + if kwargs.get("disable_cache", False): + logger.info("Caching is disabled") + else: + while True: + try: + cache = Cache(Path(".cache_dir", self.model_cache_key)) + break + except sqlite3.OperationalError: + logger.warning("Failed to open cache, retrying...", exc_info=True) + time.sleep(random.random() * 5) @cache.memoize(typed=True, ignore=[0, "func"]) def _call_function(func, *args, **kwargs): @@ -84,8 +101,10 @@ class LLMModel(ABC): grammar: llama_cpp.LlamaGrammar | None = None, stop: str | None = None, use_randomness: bool = False, + temperature: float | None = None, **kwargs, ) -> tuple[str, ChatMessages | None, ChatMessages, ModelUsage]: + raise NotImplementedError("Non-chat models are currently not supported") messages = [self._get_user_message(prompt)] if system_message is not None: prompt = system_message + prompt @@ -95,13 +114,42 @@ class LLMModel(ABC): grammar=grammar, stop=stop, use_cache=use_cache, - max_tokens=self.kwargs.get("max_tokens", None), + max_tokens=self.max_tokens, use_randomness=use_randomness, ) - messages.append(self._get_assistant_message(reponse)) return reponse, None, messages, usage + def build_demonstration_data( + self, + demonstrations: Iterable[tuple[str, str]], + instruction: list[str] | str | None, + **kwargs, + ) -> ChatMessages: + if not isinstance(self, ChatModel): + raise ValueError( + f"Model {self} does not support 
building demonstration data" + ) + + if not isinstance(instruction, list): + instruction = [instruction] + messages = [] + for (input_, output), _instruction in zip_longest( + demonstrations, instruction, fillvalue=instruction[-1] + ): + messages.extend( + self.build_input_data(input_, instruction=_instruction, **kwargs)[1] + ) + messages.append(self._get_assistant_message(output)) + return messages + + def build_input_data( + self, input_: str, instruction: str | None = None, **kwargs + ) -> ChatMessages: + return instruction, [ + self._get_user_message(input_ if input_ is not None else instruction) + ] + def _get_prediction_prefix(self): # some models use a special token prefix for the prediction return None @@ -120,21 +168,22 @@ class LLMModel(ABC): stop: str | None = None, max_tokens: int | None = None, use_randomness: bool = False, + temperature: float | None = None, ): ... - def _get_user_message(self, content: Any): + def _get_user_message(self, content: Any) -> ChatMessage: return { "role": "user", "content": content, } - def _get_system_message(self, content: Any): + def _get_system_message(self, content: Any) -> ChatMessage: return { "role": "system", "content": content, } - def _get_assistant_message(self, content: Any): + def _get_assistant_message(self, content: Any) -> ChatMessage: return { "role": "assistant", "content": content, @@ -183,7 +232,7 @@ class LLMModel(ABC): return ( str(self.model_name).replace("/", "_") + "/" - + self.get_hash_from_kwargs(**self.kwargs) + + self.get_hash_from_kwargs(**self.cache_kwargs) ) @property @@ -255,7 +304,11 @@ class Llama(LLMModel): # pass all arguments to super constructor which should be taken into account for caching # needs to be called after model is initialized - super().__init__(ignore_cache_kwargs=ignore_cache_kwargs, **hashed_model_kwargs) + super().__init__( + cache_kwargs=hashed_model_kwargs, + ignore_cache_kwargs=ignore_cache_kwargs, + **kwargs, + ) def _create_completion( self, @@ -266,6 +319,7 @@ class Llama(LLMModel): stop: str | None = None, max_tokens: int | None = None, use_randomness: bool = False, + temperature: float | None = None, ): # setup kwargs for model call model_call_kwargs = { @@ -275,8 +329,11 @@ class Llama(LLMModel): "max_tokens": max_tokens, } if use_randomness: - # same temperature as in evoprompt paper reference implementation - model_call_kwargs["temperature"] = 0.5 + if temperature is None: + # same temperature as in evoprompt paper reference implementation + model_call_kwargs["temperature"] = 0.5 + else: + model_call_kwargs["temperature"] = temperature model_call_kwargs["seed"] = random.randint(0, 2**32 - 1) else: model_call_kwargs["temperature"] = 0.0 @@ -339,40 +396,50 @@ class ChatModel: @weave.op() def create_completion( self, - system_message: str | None, - prompt: str, + system_message: str | ChatMessage | None, + messages: ChatMessages | None = None, *, + prompt: str | None = None, use_cache: bool = False, grammar: llama_cpp.LlamaGrammar | None = None, stop: str | None = None, history: ChatMessages | None = None, use_randomness: bool = False, + temperature: float | None = None, **kwargs, ) -> tuple[str, ChatMessages | None, ChatMessages, ModelUsage]: - # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case - # TODO is it better to check for a system message in the history? 
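Reviewer note (illustrative sketch, not part of the patch): the demonstration handling added above replaces the deleted `template.py` helper; instead of concatenating examples into one prompt string, demonstrations now become alternating user/assistant turns that can be passed as chat history. The sketch assumes the base `build_input_data` behavior, where the instruction is kept separate from the per-example input (subclasses such as `AlpacaHfChat` fold it in differently).

```python
from itertools import zip_longest

def build_demo_messages(
    demonstrations: list[tuple[str, str]], instruction: str | list[str] | None
) -> list[dict]:
    # one or more instructions; the last one is re-used when there are more
    # demonstrations than instructions (that is what zip_longest's fillvalue does)
    instructions = instruction if isinstance(instruction, list) else [instruction]
    messages = []
    for (input_, output), _instr in zip_longest(
        demonstrations, instructions, fillvalue=instructions[-1]
    ):
        # the demonstration input becomes a user turn ...
        messages.append({"role": "user", "content": input_ if input_ is not None else _instr})
        # ... and the gold output becomes the assistant turn the model should imitate
        messages.append({"role": "assistant", "content": output})
    return messages

demo_history = build_demo_messages(
    [("The movie was dreadful.", "negative"), ("What a delightful surprise!", "positive")],
    instruction="Classify the sentiment of the text.",
)
# demo_history can then be passed as `history=` to create_completion
```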
- messages = [self._get_user_message(prompt)] + if messages is None: + assert prompt is not None, "Either messages or prompt must be provided" + messages = [self._get_user_message(prompt)] - if history is None: - if system_message: - history = [self._get_system_message(system_message)] - else: - history = [] # we prepend the history to the messages # the chat format should take care of adding appropriate assistant messages for generating the completion - messages_for_model = history + messages + if history is None: + messages_for_model = messages + else: + messages_for_model = history + messages + # prepend system message if available + if system_message is not None: + if isinstance(system_message, str): + system_message = self._get_system_message(system_message) + messages_for_model = [system_message] + messages_for_model reponse, usage = self._create_completion( messages=messages_for_model, grammar=grammar, stop=stop, use_cache=use_cache, - max_tokens=self.kwargs.get("max_tokens", None), + max_tokens=self.max_tokens, use_randomness=use_randomness, + temperature=temperature, ) - messages.append(self._get_assistant_message(reponse)) - return reponse, history, messages, usage + return ( + reponse, + history, + messages, + usage, + ) class LlamaChat(ChatModel, Llama): @@ -387,9 +454,10 @@ class LlamaChat(ChatModel, Llama): stop: str | None, max_tokens: int | None, use_randomness: bool, + temperature: float | None = None, ): # input( - # f"The input for the model will look like this:\n{format_llama3(messages).prompt}" + # f"The input for a Llama3.x model will look like this:\n{format_llama3(messages).prompt}" # ) # setup kwargs for model call model_call_kwargs = { @@ -399,8 +467,11 @@ class LlamaChat(ChatModel, Llama): "max_tokens": max_tokens, } if use_randomness: - # same temperature as in evoprompt paper reference implementation - model_call_kwargs["temperature"] = 0.5 + if temperature is None: + # same temperature as in evoprompt paper reference implementation + model_call_kwargs["temperature"] = 0.5 + else: + model_call_kwargs["temperature"] = temperature model_call_kwargs["seed"] = random.randint(0, 2**32 - 1) else: model_call_kwargs["temperature"] = 0.0 @@ -446,7 +517,9 @@ class HfChat(ChatModel, LLMModel): ignore_cache_kwargs.extend(["torch_dtype"]) # pass all arguments to super constructor which should be taken into account for caching - super().__init__(**model_kwargs, ignore_cache_kwargs=ignore_cache_kwargs) + super().__init__( + cache_kwargs=model_kwargs, ignore_cache_kwargs=ignore_cache_kwargs, **kwargs + ) # initialize model self.pipeline = transformers.pipeline( @@ -471,23 +544,28 @@ class HfChat(ChatModel, LLMModel): stop: str | None, max_tokens: int | None, use_randomness: bool, + temperature: float | None = None, **kwargs, ): # setup kwargs for model call model_call_kwargs = { "text_inputs": messages, "stop": stop, - "max_length": max_tokens if max_tokens is not None else 2048, + # "max_length": max_tokens if max_tokens is not None else 16384, + "max_new_tokens": 1000, } if use_randomness: - # same temperature as in evoprompt paper reference implementation - model_call_kwargs["temperature"] = 0.5 + if temperature is None: + # same temperature as in evoprompt paper reference implementation + model_call_kwargs["temperature"] = 0.5 + else: + model_call_kwargs["temperature"] = temperature model_call_kwargs["do_sample"] = True else: model_call_kwargs["do_sample"] = False # input( - # f"The input for the model will look like 
this:\n{self.pipeline.tokenizer.apply_chat_template(messages, tokenize=False)}" + # f"The input for the model will look like this:\n'{self.pipeline.tokenizer.apply_chat_template(model_call_kwargs["text_inputs"], tokenize=False, add_generation_prompt=True)}'" # ) response = self._call_model( self.pipeline, @@ -506,41 +584,39 @@ class HfChat(ChatModel, LLMModel): # For Alpaca we build inputs to follow the fine-tuning chat format like https://github.com/tatsu-lab/stanford_alpaca/blob/761dc5bfbdeeffa89b8bff5d038781a4055f796a/README.md?plain=1#L56-L66 class AlpacaHfChat(HfChat): + SYSTEM_MESSAGE = "You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request." + def __init__( self, model: str = "chavinlo/alpaca-native", - ignore_cache_kwargs: list[str] | None = None, **kwargs, ): - super().__init__(model=model, ignore_cache_kwargs=ignore_cache_kwargs, **kwargs) + super().__init__(model=model, **kwargs) # chat template for Alpaca adapted from https://huggingface.co/Vezora/Mistral-22B-v0.1/blob/c15d70465e2fc46c3c4d7fec8fb62f533d4ef09b/tokenizer_config.json#L30 - self.pipeline.tokenizer.chat_template = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ system_message + '\\n\\n' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '### Instruction:\\n' + message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response:\\n' + message['content'].strip()}}{% endif %}{% endfor %}" - - def _create_completion( - self, - messages: list[dict[str, str]], - *, - use_cache: bool, - stop: str | None, - max_tokens: int | None, - use_randomness: bool, - **kwargs, - ): - # for some reason adding an empty assistant message yields different generations than adding it manually in the chat template - return super()._create_completion( - messages + [self._get_assistant_message("")], - use_cache=use_cache, - stop=stop, - max_tokens=max_tokens, - use_randomness=use_randomness, - **kwargs, - ) + self.pipeline.tokenizer.chat_template = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ system_message }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '\\n\\n### Instruction:\\n' + message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response:\\n' + message['content'].strip()}}{% endif %}{% endfor %}" + + def build_input_data( + self, input_: str, instruction: str = None, **kwargs + ) -> ChatMessages: + # For Alpaca we add the instruction for each input (and assume that content is the context, otherwise the content should contain the instruction) + if instruction is None: + assert ( + input_ is not None + ), "Either instruction or input must be provided for Alpaca" + prompt_input 
= input_ + else: + prompt_input = instruction + if input_ is not None: + prompt_input += "\n\n" + self._get_input_prefix() + input_ + return self.SYSTEM_MESSAGE, [self._get_user_message(prompt_input)] - def _get_input_prefix(self): + @staticmethod + def _get_input_prefix(): return "### Input:\n" - def _get_prediction_prefix(self): + @staticmethod + def _get_prediction_prefix(): # Alpaca uses a special token prefix for the prediction return "\n### Response:\n" @@ -559,7 +635,11 @@ class OpenAIChat(ChatModel, LLMModel): self.openai_client = openai.OpenAI() self._model_name = openai_model - super().__init__(model=openai_model, ignore_cache_kwargs=ignore_cache_kwargs) + super().__init__( + cache_kwargs=dict(model=openai_model), + ignore_cache_kwargs=ignore_cache_kwargs, + **kwargs, + ) def _create_completion( self, @@ -569,6 +649,7 @@ class OpenAIChat(ChatModel, LLMModel): stop: str | None, max_tokens: int | None, use_randomness: bool, + temperature: float | None = None, ): # setup kwargs for model call model_call_kwargs = { @@ -578,8 +659,11 @@ class OpenAIChat(ChatModel, LLMModel): "max_completion_tokens": max_tokens if max_tokens is not None else 1024, } if use_randomness: - # same temperature as in evoprompt paper reference implementation - model_call_kwargs["temperature"] = 0.5 + if temperature is None: + # same temperature as in evoprompt paper reference implementation + model_call_kwargs["temperature"] = 0.5 + else: + model_call_kwargs["temperature"] = temperature else: model_call_kwargs["temperature"] = 0.0 diff --git a/evoprompt/opt_types.py b/evoprompt/opt_types.py index a80c850ff7032e393cfbfc73c072769f857a670d..b700da47b5540073e8f471d3f74c7a423e2556ac 100644 --- a/evoprompt/opt_types.py +++ b/evoprompt/opt_types.py @@ -1,5 +1,6 @@ import json from dataclasses import dataclass, field, is_dataclass +from typing import Literal, NamedTuple, TypedDict from uuid import uuid4 @@ -25,6 +26,29 @@ class ModelUsage: ) +PromptSource = Literal[ + "baseprompt", + "baseprompt_file", + "baseprompt_gen", + "paraphrase", + "evolution", + "corrected", +] + + +class Judgement(NamedTuple): + original_response: str + corrected_response: str + happy: bool | None + skip: bool + + +class PromptMeta(TypedDict): + gen: int + source: PromptSource + judgements: list[Judgement] + + @dataclass(frozen=True) class Prompt: content: str diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py index c1147f40f3b1a1c7c0df1ed0358012149c80acc6..cc3710aca89103a63f28fc43768ec23387a2d978 100644 --- a/evoprompt/optimization.py +++ b/evoprompt/optimization.py @@ -4,7 +4,7 @@ import re from collections import OrderedDict from difflib import Differ from pathlib import Path -from typing import Any, Literal, NamedTuple, Optional, TypedDict +from typing import Any, Optional import wandb from rich.panel import Panel @@ -18,7 +18,13 @@ from tqdm import tqdm, trange from evoprompt.cli import argument_parser from evoprompt.models import ChatMessages, LLMModel -from evoprompt.opt_types import ModelUsage, OptTypeEncoder, Prompt +from evoprompt.opt_types import ( + Judgement, + ModelUsage, + OptTypeEncoder, + Prompt, + PromptMeta, +) from evoprompt.task import Task from evoprompt.utils import initialize_run_directory, log_calls @@ -27,21 +33,6 @@ logger = logging.getLogger(__name__) PARAPHRASE_PROMPT = """You are given an instruction that describes a task. Write a response that paraphrases the instruction. 
Only output the paraphrased instruction bracketed in <prompt> and </prompt>.""" -PromptSource = Literal["baseprompt", "paraphrase", "evolution", "corrected"] - - -class Judgement(NamedTuple): - original_response: str - corrected_response: str - happy: bool | None - skip: bool - - -class PromptMeta(TypedDict): - gen: int - source: PromptSource - judgements: list[Judgement] - class ResponseEditor(App): BINDINGS = [ @@ -56,27 +47,41 @@ class ResponseEditor(App): instruction: str, original_response: str, history: ChatMessages, + recent_turn: ChatMessages, judge_response: str, ): self.instruction = instruction self.response = original_response self.history = history + self.recent_turn = recent_turn self.judge_response = judge_response self.skip = False # used to mark the prompt as skipped super().__init__() def compose(self) -> ComposeResult: self.text_area = TextArea.code_editor(self.response, soft_wrap=True) + if self.history is not None: + yield ScrollableContainer( + *( + Collapsible( + Static(message["content"]), + title=message["role"], + collapsed=True, + ) + for message in self.history + ) + ) yield ScrollableContainer( *( Collapsible( Static(message["content"]), title=message["role"], - collapsed=idx != len(self.history) - 1, + collapsed=False, ) - for idx, message in enumerate(self.history) + for message in self.recent_turn ) ) + yield ScrollableContainer( Label(Panel(self.judge_response, title="Judge response")), Label(Rule(title="Response to edit"), expand=True), @@ -153,13 +158,11 @@ class PromptOptimization: *, task: Task, evolution_model: LLMModel, - evaluation_model: LLMModel, judge_model: LLMModel | None, run_options: dict[str, Any] = {}, ) -> None: self.task = task self.evolution_model = evolution_model - self.evaluation_model = evaluation_model self.judge_model = judge_model self.run_options = run_options @@ -174,8 +177,8 @@ class PromptOptimization: return self.task.evaluate_validation(prompt, parent_histories) def get_initial_prompts(self, num_initial_prompts: int, debug: bool = False): - # this implements the para_topk algorothm from https://github.com/beeevita/EvoPrompt - base_prompts = self.task.base_prompts + # this implements the para_topk algorithm from https://github.com/beeevita/EvoPrompt + base_prompts, base_prompts_sources = self.task.base_prompts if debug: base_prompts = base_prompts[:2] @@ -188,15 +191,14 @@ class PromptOptimization: # take at most half of the best prompts sorted_results = sorted( - zip(evaluation_results, base_prompts), + zip(evaluation_results, base_prompts, base_prompts_sources), key=lambda x: x[0][0], # sort by score reverse=True, # best first ) - top_prompts = [ - prompt for _, prompt in sorted_results[: num_initial_prompts // 2] - ] - initial_population = top_prompts.copy() - prompt_sources = ["baseprompt" for _ in initial_population] + sorted_results = sorted_results[: num_initial_prompts // 2] + _, top_prompts, prompt_sources = zip(*sorted_results) + initial_population = list(top_prompts) + prompt_sources = list(prompt_sources) # fill up the rest with paraphrases of the top prompts promptindex_to_paraphrase = 0 @@ -349,7 +351,7 @@ class PromptOptimization: self, instruction: str, response: str, - history: ChatMessages, + history: ChatMessages | None, recent_turn: ChatMessages, ) -> Judgement: # TODO potentially move to separate class wrapping the judge model and related functionality @@ -359,27 +361,37 @@ class PromptOptimization: # judge the actual response # concatenate all user and assistant messages to provide context - history_str = 
"\n".join( - message["content"] - for message in history - if message["role"] in ["user", "assistant"] - ) - # TODO What if the history does not exist (is empty), i.e., for the first step in de-cot? - prompt = ( - f"Context: {history_str}\nInstruction: {instruction}\nResponse: {response}" - ) - system_message = ( - "You are acting as a judge. Please read the context, the instruction and the response " - "and decide if the response follows the instruction. " - "If it does, answer 'good'. If it does not, answer 'bad'. " - "Wrap the answer with tags <judgement> and </judgement>. " - "Please also add an explanation for your judgement." - ) + # if there is no history, only show instruction and response + if history: + history_str = "\n".join( + message["content"] + for message in history + if message["role"] in ["user", "assistant"] + ) + prompt = f"Context:\n{history_str}\n\nInstruction:\n{instruction}\n\nResponse:\n{response}" + system_message = ( + "You are acting as a judge. Please read the context, the instruction and the response " + "and decide if the response follows the instruction. " + "If it does, answer 'good'. If it does not, answer 'bad'. " + "Wrap the answer with tags <judgement> and </judgement>. " + "Please also add an explanation for your judgement." + ) + else: + prompt = f"Instruction:\n{instruction}\n\nResponse:\n{response}" + system_message = ( + "You are acting as a judge. Please read the instruction and the response " + "and decide if the response follows the instruction. " + "If it does, answer 'good'. If it does not, answer 'bad'. " + "Wrap the answer with tags <judgement> and </judgement>. " + "Please also add an explanation for your judgement." + ) + # input(f"System message:\n{system_message}\n\nPrompt:\n{prompt}\n") judgement_response, _, _, _ = self.judge_model.create_completion( system_message=system_message, prompt=prompt, ) + # input(f"Judgement response:\n{judgement_response}\n") matches = re.findall( # regex that matches `good` and `bad` between <judgement> and </judgement> where additional characters can be present, e.g., whitespace r"<judgement>.*(good|bad).*</judgement>", @@ -409,7 +421,8 @@ class PromptOptimization: editor = ResponseEditor( instruction, response, - history[:-1], + history if history is not None else None, + recent_turn=recent_turn, judge_response=judgement_response, ) editor.run() @@ -420,13 +433,12 @@ class PromptOptimization: delta = Differ().compare( response.splitlines(), editor.modified_response.splitlines() ) + delta = [ + line for line in delta if line.startswith("+") or line.startswith("-") + ] logger.info( - "User corrected prompt (delta):\n%s", - "\n".join( - line - for line in delta - if line.startswith("+") or line.startswith("-") - ), + "User corrected prompt (delta):%s", + ("\n" + "\n".join(delta)) if delta else " No changes", # "User corrected prompt:\n'%s'\n -> \n'%s'", # response, # editor.modified_response, diff --git a/evoprompt/task/base_prompts_mixin.py b/evoprompt/task/base_prompts.py similarity index 53% rename from evoprompt/task/base_prompts_mixin.py rename to evoprompt/task/base_prompts.py index c4d00cd99d71788623ce4b6d1e67e5b5dfd14e60..2eb9ad713c1be95c8e7278331a5427a2401e2b12 100644 --- a/evoprompt/task/base_prompts_mixin.py +++ b/evoprompt/task/base_prompts.py @@ -10,7 +10,7 @@ from evoprompt.utils import get_rng class BasePromptsFromJsonMixin: @staticmethod - def _load_json_file(path: str): + def _load_json_file(path: str) -> list[str]: with Path(path).open() as json_file: return json.load(json_file) @@ -20,41 
+20,48 @@ class BasePromptsFromJsonMixin: raise Exception( f"Class {self.__class__} does not exhibit attribute `base_prompts_files` which is needed for `BasePromptsFromJsonMixin`." ) - base_prompts = [] + prompts, sources = super().base_prompts + prompts_from_files = [] for prompt_file in self.base_prompts_files: - base_prompts += self._load_json_file(prompt_file) - return base_prompts + prompts_from_files += self._load_json_file(prompt_file) + prompts += prompts_from_files + sources += ["baseprompt_file"] * len(prompts_from_files) + return prompts, sources -class BasePromptsFromGeneration: +class BasePromptsFromGenerationMixin: def __init__(self, *args, **kwargs) -> None: self.evolution_model: LLMModel = kwargs.get("evolution_model") super().__init__(*args, **kwargs) # this implements the initial population generation from Zhou et al., 2023: Large Language Models are Human-Level Prompt Engineers + # patience allows to stop the generation process if no new prompts can be generated + # can be set to -1 to generate as many prompts as needed (but can possibly run forever) def generate_prompt( self, num_prompts: int, patience: int = 10, allow_duplicates: bool = False - ) -> str: + ) -> list[str]: self.validation_dataset: Dataset samples = self.validation_dataset.shuffle(42).select( get_rng().choice(len(self.validation_dataset), 5, replace=False) ) - prompt = "I gave a friend an instruction and five inputs. The friend read the instruction and wrote an output for every one of the inputs. Here are the input-output pairs:\n" - raise NotImplementedError( - "The prompt needs to be adapted for the model taking into account the correct format." + prompt = "I gave a friend a single instruction and five inputs. The friend read the instruction and wrote an output for every one of the inputs. Here are the input-output pairs:\n\n" + prompt += "\n".join( + f"Input:\n{self._get_prompt_text_for_datum(sample)}\nOutput:\n{self._get_gold_label_generation_for_datum(sample)}\n" + for sample in samples ) - prompt = self.build_demonstration_prompt(samples, prompt=prompt) prompt += "\nThe instruction was " system_message = "You are a helpful assistant. Please provide the instruction wrapped within tags <instruction> and </instruction> that belongs to the given input-output pairs." 
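Reviewer note (illustrative sketch, not part of the patch): the instruction-induction step this mixin implements (following Zhou et al., 2023) asks the model to describe the instruction behind a handful of input-output pairs and then pulls the answer out of `<instruction>` tags. The regex matches the one used in the loop that follows; the sample response text is made up.

```python
import re

def extract_induced_instruction(response: str) -> str | None:
    # anything between <instruction> and an optional closing </instruction> tag
    matches = re.findall(
        r"<instruction>(.+?)(?:(?=</instruction>)|$)",
        response,
        flags=re.IGNORECASE,
    )
    return matches[-1].strip() if matches else None

response = "The instruction was <instruction>Answer the question using only the given context.</instruction>"
assert (
    extract_induced_instruction(response)
    == "Answer the question using only the given context."
)
```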
- input(prompt) + messages = [ + self.evolution_model._get_user_message(prompt) + ] # , self.evolution_model._get_assistant_message("The instruction was ")] generated_prompts = [] while len(generated_prompts) < num_prompts: response, _, _, _ = self.evolution_model.create_completion( system_message=system_message, - prompt=prompt, + messages=messages, + use_randomness=True, ) - input(response) matches = re.findall( # regex that extracts anything within tags <instruction> and optional </instruction> rf"<instruction>(.+?)(?:(?=</instruction>)|$)", @@ -62,9 +69,9 @@ class BasePromptsFromGeneration: flags=re.IGNORECASE, ) if matches: - prompt = matches[-1].strip() - if allow_duplicates or prompt not in generated_prompts: - generated_prompts.append(matches[-1].strip()) + generated_prompt = matches[-1].strip() + if allow_duplicates or generated_prompt not in generated_prompts: + generated_prompts.append(generated_prompt) else: if patience == 0: break @@ -74,6 +81,15 @@ class BasePromptsFromGeneration: @property def base_prompts(self): - num_prompts = getattr(self, "num_generated_base_prompts", 0) + if not hasattr(self, "num_generated_base_prompts"): + raise AttributeError( + f"{self.__class__} must expose attribute `num_generated_base_prompts`" + ) + prompts, sources = super().base_prompts + + num_prompts = self.num_generated_base_prompts + generated_prompts = self.generate_prompt(num_prompts) + prompts += generated_prompts + sources += ["baseprompt_gen"] * len(generated_prompts) - return self.generate_prompt(num_prompts) + return prompts, sources diff --git a/evoprompt/task/question_answering.py b/evoprompt/task/question_answering.py index 6266c86d03b879424276242f2534f77d1cbcf589..fddf81a4473f4047f08eb39f2dc289a487fc43ee 100644 --- a/evoprompt/task/question_answering.py +++ b/evoprompt/task/question_answering.py @@ -8,7 +8,7 @@ from datasets import Dataset from evaluate import load as load_metric from llama_cpp import LlamaGrammar -from evoprompt.task.base_prompts_mixin import BasePromptsFromGeneration +from evoprompt.task.base_prompts import BasePromptsFromGenerationMixin from evoprompt.opt_types import ModelUsage from evoprompt.task.task import DatasetDatum, Task from evoprompt.utils import get_rng @@ -172,7 +172,7 @@ class QuestionAnswering(Task): return "f1" -class SQuAD(BasePromptsFromGeneration, QuestionAnswering): +class SQuAD(BasePromptsFromGenerationMixin, QuestionAnswering): shorthand = "squad" num_generated_base_prompts = 10 @@ -202,7 +202,7 @@ class SQuAD(BasePromptsFromGeneration, QuestionAnswering): @property def base_prompts(self): - generated_base_prompts = super().base_prompts - return [ - "In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context." - ] + generated_base_prompts + prompts, sources = super().base_prompts + prompts.append("In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. 
Make sure that the answer is taken directly from the context.") + sources.append("baseprompt") + return prompts, sources diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py index bf72c175ed7bf2f6e4ad763c5ef69f0c72ed125b..89b092a627bacb935f81bd44555c0d436b6b0987 100644 --- a/evoprompt/task/sentiment_analysis.py +++ b/evoprompt/task/sentiment_analysis.py @@ -4,7 +4,7 @@ from typing import Mapping from datasets import load_dataset -from evoprompt.task.base_prompts_mixin import BasePromptsFromJsonMixin +from evoprompt.task.base_prompts import BasePromptsFromJsonMixin from evoprompt.task import TextClassification from evoprompt.task.task import DatasetDatum diff --git a/evoprompt/task/simplification.py b/evoprompt/task/simplification.py index 37e0ec50f36f829bd01d3c8db290ac2db0c810de..3a2e6aafad86d558824e783aec3654674fb49e68 100644 --- a/evoprompt/task/simplification.py +++ b/evoprompt/task/simplification.py @@ -2,7 +2,7 @@ import logging from evaluate import load as load_metric -from evoprompt.task.base_prompts_mixin import BasePromptsFromJsonMixin +from evoprompt.task.base_prompts import BasePromptsFromJsonMixin from evoprompt.task import TextGeneration from evoprompt.task.task import DatasetDatum diff --git a/evoprompt/task/subjectivity_classification.py b/evoprompt/task/subjectivity_classification.py index 7c3882e0a96a9c89db131220df97e7aeecfa15ab..10fea8b565e9f9bea017abf5967fa98deb385c18 100644 --- a/evoprompt/task/subjectivity_classification.py +++ b/evoprompt/task/subjectivity_classification.py @@ -3,7 +3,7 @@ from typing import Mapping from datasets import load_dataset -from evoprompt.task.base_prompts_mixin import BasePromptsFromJsonMixin +from evoprompt.task.base_prompts import BasePromptsFromJsonMixin from evoprompt.task import TextClassification from evoprompt.task.task import DatasetDatum diff --git a/evoprompt/task/summarization.py b/evoprompt/task/summarization.py index fed21c15fbe2a409bfb2f4866242b96b4e48b16c..4084d277e76acc3f2ac82f409383976d7cf10f98 100644 --- a/evoprompt/task/summarization.py +++ b/evoprompt/task/summarization.py @@ -2,7 +2,7 @@ import logging from evaluate import load as load_metric -from evoprompt.task.base_prompts_mixin import BasePromptsFromJsonMixin +from evoprompt.task.base_prompts import BasePromptsFromJsonMixin from evoprompt.task import TextGeneration from evoprompt.task.task import DatasetDatum diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py index 6b2a3344d2f3bf8f30787fb97c247282090077e7..90f128b729f7c2aa96d9831943fbea7d91fb8e53 100644 --- a/evoprompt/task/task.py +++ b/evoprompt/task/task.py @@ -10,15 +10,13 @@ from datasets import Dataset, load_dataset from llama_cpp import LlamaGrammar from tqdm import tqdm -from evoprompt.models import LLMModel -from evoprompt.opt_types import ModelUsage +from evoprompt.models import ChatMessage, ChatMessages, LLMModel +from evoprompt.opt_types import ModelUsage, PromptSource from evoprompt.utils import log_calls logger = logging.getLogger(__name__) -SYSTEM_MESSAGE = "You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request." 
- DatasetDatum = dict @@ -344,7 +342,6 @@ class Task(metaclass=ABCMeta): ) -> Iterable[int]: pass - @log_calls("Evaluating validation dataset") def evaluate_validation( self, prompt: str, parent_histories: list[list[float]] | None = None @@ -371,11 +368,15 @@ class Task(metaclass=ABCMeta): evaluation_history = [] # augment prompt with demonstration samples - prompt_with_examples = self.build_demonstration_prompt(self.demonstration_samples, prompt=prompt) + demonstration_prompt = self.build_demonstration_prompt( + self.demonstration_samples, instruction=prompt + ) for datum in dataset_iterator: # run prediction - response, usage = self.predict(prompt=prompt_with_examples, datum=datum) + response, usage = self.predict( + instruction=prompt, history=demonstration_prompt, datum=datum + ) logger.debug(f"Response: '{response}'") # parse response response = self._parse_response(response=response) @@ -399,17 +400,22 @@ class Task(metaclass=ABCMeta): break return self._aggregate_result(results), evaluation_usage, evaluation_history - + @weave.op() - def predict(self, prompt: str, datum: DatasetDatum) -> tuple[str, ModelUsage]: + def predict( + self, instruction: str, history: ChatMessages, datum: DatasetDatum + ) -> tuple[str, ModelUsage]: # run model for inference using grammar to constrain output # TODO grammar also depends on prompt and vice-versa -> what are good labels? # build prompt for current sample - prompt_for_datum = self.build_prompt_input(datum, prompt=prompt, use_prediction_prefix=self.model._get_prediction_prefix() is None) + system_message, prompt_for_datum = self.build_prompt_input( + datum, instruction, use_prediction_prefix=self.force_task_prediction_prefix + ) logger.debug(f"Prompt for datum:\n{prompt_for_datum}") response, _, _, usage = self.model.create_completion( - system_message=SYSTEM_MESSAGE, - prompt=prompt_for_datum, + system_message=system_message, + messages=prompt_for_datum, + history=history, # grammar can be applied to constrain the model output grammar=self._get_grammar(datum) if self.use_grammar else None, # we use cached completions to speed up the process although we loose the non-deterministic behavior of LMs, but we're ok with a single result @@ -423,28 +429,43 @@ class Task(metaclass=ABCMeta): response = response.strip() return response, usage - + def build_prompt_input( - self, sample, prompt: str = "", use_prediction_prefix: bool = False, - ) -> str: + self, + sample: DatasetDatum, + instruction: str, + use_prediction_prefix: bool = False, + ) -> tuple[ChatMessage, ChatMessages]: # the default is to use the prompt as is and concatenate the datum string - prompt += f"\n\n{self.model._get_input_prefix() if self.model._get_input_prefix() is not None else ""}{self._get_prompt_text_for_datum(sample, use_prefix=self.force_task_input_prefix or not self.model._get_input_prefix())}" + datum_input = self._get_prompt_text_for_datum( + sample, use_prefix=self.force_task_input_prefix + ) if use_prediction_prefix: - prompt += f"\n{self._get_prediction_prefix().strip()} " - return prompt.strip() - + datum_input += f"\n{self._get_prediction_prefix().strip()}" + return self.model.build_input_data(datum_input, instruction) + def build_demonstration_prompt( self, - demonstration_samples: list[dict], - prompt: str = "", - ) -> str: - for sample in demonstration_samples: - prompt += "\n\n" + self.build_prompt_input(sample) - prompt += f"\n{self.model._get_prediction_prefix() if self.model._get_prediction_prefix() is not None else 
self._get_prediction_prefix()}{self._get_gold_label_generation_for_datum(sample)}" - return prompt.strip() - + demonstration_samples: Iterable[DatasetDatum], + instruction: str = None, + ) -> ChatMessages: + return self.model.build_demonstration_data( + [ + ( + self._get_prompt_text_for_datum( + sample, use_prefix=self.force_task_input_prefix + ), + self._get_gold_label_generation_for_datum(sample), + ) + for sample in demonstration_samples + ], + instruction=instruction, + ) + @abstractmethod - def _get_prompt_text_for_datum(self, datum: DatasetDatum, use_prefix: bool = False) -> str: ... + def _get_prompt_text_for_datum( + self, datum: DatasetDatum, use_prefix: bool = False + ) -> str: ... @abstractmethod def _get_prediction_prefix() -> str: ... @@ -471,4 +492,5 @@ class Task(metaclass=ABCMeta): @property @abstractmethod - def base_prompts(self) -> list[str]: ... + def base_prompts(self) -> tuple[list[str], list[PromptSource]]: + return [], [] diff --git a/evoprompt/task/text_classification.py b/evoprompt/task/text_classification.py index af30f1545b21f5e43ff674c89a3305abb0d9153f..be67da9da44f0c8a7ab0ef4750553866175989b7 100644 --- a/evoprompt/task/text_classification.py +++ b/evoprompt/task/text_classification.py @@ -38,7 +38,7 @@ class TextClassification(Task): flags=re.IGNORECASE, ) if matches: - return matches[-1] + return matches[0] else: return response diff --git a/evoprompt/task/topic_classification.py b/evoprompt/task/topic_classification.py index dd1905f1d46ee0ff1d773aaba36a78c12a4cd7d4..6230f85c98c1551fdcfb95c8b54d9b4bd79bc050 100644 --- a/evoprompt/task/topic_classification.py +++ b/evoprompt/task/topic_classification.py @@ -3,7 +3,7 @@ from typing import Mapping from datasets import load_dataset -from evoprompt.task.base_prompts_mixin import BasePromptsFromJsonMixin +from evoprompt.task.base_prompts import BasePromptsFromJsonMixin from evoprompt.task import TextClassification from evoprompt.task.task import DatasetDatum diff --git a/evoprompt/utils.py b/evoprompt/utils.py index 701ae38dc558f789e2ae0eef23e62aac398a1bd1..8db75aa63604d0d4392b107708cca3b67b666244 100644 --- a/evoprompt/utils.py +++ b/evoprompt/utils.py @@ -57,7 +57,7 @@ def setup_console_logger(verbosity_level: int = 0): logging.basicConfig(handlers=(console_handler,), level=logging.NOTSET) -run_name_prompt = ( +RUN_NAME_PROMPT = ( "Generate a random name that sounds german or dutch. " "The parts should be separated by underscores and contain only lowercase. " "Only return the name without any text before or after." 
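Reviewer note (illustrative sketch, not part of the patch): the hunk below passes an explicit temperature for run-name generation. The new optional `temperature` argument threads through each backend's `_create_completion` with the same resolution rule, summarized here as a small helper that mirrors the repeated if/else blocks in this diff.

```python
def resolve_temperature(use_randomness: bool, temperature: float | None) -> float:
    if not use_randomness:
        # deterministic decoding (greedy / temperature 0)
        return 0.0
    if temperature is None:
        # default from the EvoPrompt reference implementation
        return 0.5
    # explicit per-call override, e.g. 1.2 for run-name generation below
    return temperature
```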
@@ -80,8 +80,10 @@ def initialize_run_directory(model): # make sure that we use high randomness for generating the run name even if a seed is set for the model response, _, _, _ = model.create_completion( system_message=None, - prompt=run_name_prompt, + prompt=RUN_NAME_PROMPT, use_randomness=True, + # a bit more randomness for the name is okay + temperature=1.2, ) run_name_match = re.search(r"^\w+$", response, re.MULTILINE) existing_run_names = os.listdir(RUNS_DIR) if RUNS_DIR.exists() else [] diff --git a/main.py b/main.py index 40c8d25dd7e2126e8f5acc83aa5548391e9b532f..ae7c518cffbe090d83aa570d1861021692d461de 100644 --- a/main.py +++ b/main.py @@ -59,8 +59,9 @@ if __name__ == "__main__": if options.wandb_project is not None: # init wandb and weave tracing (with disabled call link printing) - weave_settings = UserSettings(disabled=False, print_call_link=False) - weave.init(project_name=options.wandb_project, settings=weave_settings) + # TODO weave recently had 500 errors quite often, so we disable it for now + # weave_settings = UserSettings(disabled=False, print_call_link=False) + # weave.init(project_name=options.wandb_project, settings=weave_settings) wandb.init(project=options.wandb_project, config=options.__dict__) # set up console logging and rnd @@ -123,7 +124,6 @@ if __name__ == "__main__": population_size=10, task=task, evolution_model=evolution_model, - evaluation_model=evaluation_model, judge_model=judge_model, run_options=options.__dict__, )
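Reviewer note (illustrative sketch, not part of the patch): one design point worth spelling out is that `base_prompts` now returns a `(prompts, sources)` pair, and the JSON and generation mixins each extend whatever `super().base_prompts` yields, so the MRO decides the final order. The class names and prompt strings below are illustrative stand-ins.

```python
class TaskBase:
    @property
    def base_prompts(self) -> tuple[list[str], list[str]]:
        return [], []

class FromJsonMixin(TaskBase):
    @property
    def base_prompts(self):
        prompts, sources = super().base_prompts
        file_prompts = ["Classify the sentiment of the text."]  # stand-in for JSON file content
        return prompts + file_prompts, sources + ["baseprompt_file"] * len(file_prompts)

class FromGenerationMixin(TaskBase):
    @property
    def base_prompts(self):
        prompts, sources = super().base_prompts
        generated = ["Answer the question given the context."]  # stand-in for LLM-generated prompts
        return prompts + generated, sources + ["baseprompt_gen"] * len(generated)

class MyTask(FromJsonMixin, FromGenerationMixin):
    pass

prompts, sources = MyTask().base_prompts
# prompts: ['Answer the question given the context.', 'Classify the sentiment of the text.']
# sources: ['baseprompt_gen', 'baseprompt_file']
```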