From f57cbe077d97f825da356ccbf3ce05e6c0575fce Mon Sep 17 00:00:00 2001 From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de> Date: Thu, 10 Oct 2024 16:12:30 +0200 Subject: [PATCH] Use chat format for evolution demonstration samples --- evoprompt/evolution/evolution.py | 259 +++++++++++++++++++++---------- evoprompt/evolution/template.py | 6 - 2 files changed, 173 insertions(+), 92 deletions(-) delete mode 100644 evoprompt/evolution/template.py diff --git a/evoprompt/evolution/evolution.py b/evoprompt/evolution/evolution.py index ce57de5..b5d71f9 100644 --- a/evoprompt/evolution/evolution.py +++ b/evoprompt/evolution/evolution.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable import logging import re from abc import ABCMeta, abstractmethod @@ -7,7 +8,6 @@ from tqdm import trange import wandb import weave -from evoprompt.evolution.template import get_demonstration_prompt_template from evoprompt.evolution.template_de import ( DE_DEMONSTRATION_DATA_CLS, DE_DEMONSTRATION_DATA_SIM, @@ -91,6 +91,42 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta): ) -> tuple[str, list[Judgement], ModelUsage]: pass + def build_demonstration_prompt( + self, + demonstration_samples: Iterable[tuple[str, str]], + instruction: str = None, + ) -> ChatMessages: + return self.evolution_model.build_demonstration_data( + demonstration_samples, + instruction=instruction, + ) + + def parse_response( + self, + response: str, + start_token: str = "<prompt>", + end_token: str = "</prompt>", + allow_missing_end_token: bool = True, + ): + matches = re.findall( + # regex that matches any characters between last pair of `start_token` and `end_token`, and optionally allow missing `end_token` + rf"{start_token}(?!.*{start_token})(?:(.*){end_token}{"|(.*)" if allow_missing_end_token else ""})", + response, + flags=(re.IGNORECASE | re.DOTALL), + ) + if matches and any(matches[0]): + # there is always only a single match, and one group should be non-empty + if matches[0][0]: + evolved_prompt = matches[0][0] + else: + assert matches[0][1] + evolved_prompt = matches[0][1] + + else: + # could not extract prompt -> no new prompt + evolved_prompt = None + return evolved_prompt + @abstractmethod def update(self, *args, **kwargs): pass @@ -245,20 +281,23 @@ class GeneticAlgorithm(EvolutionAlgorithm): # Based on this two-step process, we design instructions, guiding LLMs to # generate a new prompt based on these steps to perform Evo(·) in Algorithm 1. 
-        filled_prompt = self.get_prompt_template().format(
+        prompt = self._get_prompt_template()
+        if isinstance(prompt, tuple) and len(prompt) > 1:
+            # extract demonstrations
+            prompt, demo_messages = prompt
+        filled_prompt = prompt.format(
             prompt1=prompt_1,
             prompt2=prompt_2,
         )
-        evolved_prompt, history, recent_turn, usage = (
-            self.evolution_model.create_completion(
-                system_message=SYSTEM_MESSAGE,
-                prompt=filled_prompt,
-                use_randomness=True,
-            )
+        response, history, recent_turn, usage = self.evolution_model.create_completion(
+            system_message=SYSTEM_MESSAGE,
+            prompt=filled_prompt,
+            history=demo_messages if self.use_evolution_demo else None,
+            use_randomness=True,
         )
         judgement = self.judge_and_correct_step(
-            filled_prompt, evolved_prompt, history, recent_turn
+            filled_prompt, response, history, recent_turn
         )

         if judgement.skip:
@@ -269,17 +308,18 @@ class GeneticAlgorithm(EvolutionAlgorithm):
                 usage,
             )

-        evolved_prompt = judgement.corrected_response
-
-        if "<prompt>" in evolved_prompt:
-            evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
+        response = judgement.corrected_response
+        evolved_prompt = self.parse_response(response)

-        logger.info(
-            "GA-evolved prompts '%s' and '%s' into '%s'",
-            prompt_1,
-            prompt_2,
-            evolved_prompt,
-        )
+        if evolved_prompt is None:
+            logger.info("Could not extract prompt from response: %s", response)
+        else:
+            logger.info(
+                "GA-evolved prompts '%s' and '%s' into '%s'",
+                prompt_1,
+                prompt_2,
+                evolved_prompt,
+            )

         return evolved_prompt, [judgement], usage

@@ -305,22 +345,28 @@ class GeneticAlgorithm(EvolutionAlgorithm):

         return retained_prompts

-    def get_prompt_template(self):
+    def _get_prompt_template(self):
         if self.use_evolution_demo:
             if isinstance(
                 self.task, (TextClassification, Summarization, QuestionAnswering)
             ):
-                return get_demonstration_prompt_template(
-                    GA_PROMPT, GA_DEMONSTRATION_DATA_SIM
-                )
+                demonstration_data = GA_DEMONSTRATION_DATA_SIM
             elif isinstance(self.task, Simplification):
-                return get_demonstration_prompt_template(
-                    GA_PROMPT, GA_DEMONSTRATION_DATA_CLS
-                )
+                demonstration_data = GA_DEMONSTRATION_DATA_CLS
             else:
                 raise NotImplementedError(
                     f"Prompt with demonstration data is not implemented for task of type {type(self.task)}."
                 )
+
+            prompt_with_demonstration_data = self.build_demonstration_prompt(
+                [
+                    (
+                        GA_PROMPT.format(**demonstration_data),
+                        demonstration_data["response"],
+                    )
+                ]
+            )
+            return GA_PROMPT, prompt_with_demonstration_data

         return GA_PROMPT

@@ -346,22 +392,25 @@ class DifferentialEvolution(EvolutionAlgorithm):
             prompts_current_evolution, key=lambda prompt: prompt.score
         )

-        filled_prompt = self.get_prompt_template().format(
+        prompt = self._get_prompt_template()
+        if isinstance(prompt, tuple) and len(prompt) > 1:
+            # extract demonstrations
+            prompt, demo_messages = prompt
+        filled_prompt = prompt.format(
             prompt1=prompt_1,
             prompt2=prompt_2,
             prompt3=best_prompt_current_evolution,
             basic_prompt=prompts_current_evolution[current_iteration],
         )
-        evolved_prompt, history, recent_turn, usage = (
-            self.evolution_model.create_completion(
-                system_message=SYSTEM_MESSAGE,
-                prompt=filled_prompt,
-                use_randomness=True,
-            )
+        response, history, recent_turn, usage = self.evolution_model.create_completion(
+            system_message=SYSTEM_MESSAGE,
+            prompt=filled_prompt,
+            history=demo_messages if self.use_evolution_demo else None,
+            use_randomness=True,
         )
         judgement = self.judge_and_correct_step(
-            filled_prompt, evolved_prompt, history, recent_turn
+            filled_prompt, response, history, recent_turn
        )

         if judgement.skip:
@@ -372,33 +421,20 @@ class DifferentialEvolution(EvolutionAlgorithm):
                 usage,
             )

-        evolved_prompt = judgement.corrected_response
+        response = judgement.corrected_response
+        evolved_prompt = self.parse_response(response)

-        matches = re.findall(
-            # regex that matches any characters between last pair of <prompt></prompt>, also if </prompt> is missing
-            r"<prompt>(?!.*<prompt>)(?:(.*)</prompt>|(.*))",
-            evolved_prompt,
-            flags=(re.IGNORECASE | re.DOTALL),
-        )
-        if matches and any(matches[0]):
-            # there is always only a single match, and one group should be non-empty
-            if matches[0][0]:
-                evolved_prompt = matches[0][0]
-            else:
-                assert matches[0][1]
-                evolved_prompt = matches[0][1]
+        if evolved_prompt is None:
+            logger.info("Could not extract prompt from response: %s", response)
         else:
-            # TODO what to do in this case? Discard generated prompt directly?
-            pass
-
-        logger.info(
-            "DE-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
-            prompt_1,
-            prompt_2,
-            best_prompt_current_evolution,
-            prompts_current_evolution[current_iteration],
-            evolved_prompt,
-        )
+            logger.info(
+                "DE-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
+                prompt_1,
+                prompt_2,
+                best_prompt_current_evolution,
+                prompts_current_evolution[current_iteration],
+                evolved_prompt,
+            )

         return evolved_prompt, [judgement], usage

@@ -416,22 +452,28 @@ class DifferentialEvolution(EvolutionAlgorithm):
         ]
         return population

-    def get_prompt_template(self):
+    def _get_prompt_template(self):
         if self.use_evolution_demo:
             if isinstance(
                 self.task, (TextClassification, Summarization, QuestionAnswering)
             ):
-                return get_demonstration_prompt_template(
-                    DE_PROMPT, DE_DEMONSTRATION_DATA_SIM
-                )
+                demonstration_data = DE_DEMONSTRATION_DATA_SIM
             elif isinstance(self.task, Simplification):
-                return get_demonstration_prompt_template(
-                    DE_PROMPT, DE_DEMONSTRATION_DATA_CLS
-                )
+                demonstration_data = DE_DEMONSTRATION_DATA_CLS
             else:
                 raise NotImplementedError(
                     f"Prompt with demonstration data is not implemented for task of type {type(self.task)}."
                 )
+
+            prompt_with_demonstration_data = self.build_demonstration_prompt(
+                [
+                    (
+                        DE_PROMPT.format(**demonstration_data),
+                        demonstration_data["response"],
+                    )
+                ]
+            )
+            return DE_PROMPT, prompt_with_demonstration_data

         return DE_PROMPT

@@ -457,27 +499,46 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
             prompts_current_evolution, key=lambda prompt: prompt.score
         )

-        history = None
+        # list of evolution steps
+        evolutions_steps = []
+        # list (turns) of list (demonstrations)
+        demos = [[]]
         response: str = ""
         judgements: list[Judgement] = []
         usage: ModelUsage = ModelUsage()
-        for idx, prompt in enumerate(self.get_prompt_template()):
+        for idx, prompt in enumerate(self._get_prompt_template()):
+            if isinstance(prompt, tuple) and len(prompt) > 1:
+                # extract demonstrations
+                prompt, demo_messages = prompt
+                demos[-1].extend(demo_messages)
+            if self.use_evolution_demo:
+                messages_demos = self.condense_messages(demos[-1])
+            else:
+                messages_demos = []
             filled_prompt = prompt.format(
                 prompt1=prompt_1,
                 prompt2=prompt_2,
                 prompt3=best_prompt_current_evolution,
                 basic_prompt=prompts_current_evolution[current_iteration],
             )
+            evolutions_steps.append(
+                self.evolution_model._get_user_message(filled_prompt)
+            )
+            # TODO Shall we still use only a single turn containing all messages if we do not use demonstrations for evolution?
+            messages = self.condense_messages(evolutions_steps)
             response, history, recent_turn, usage = (
                 self.evolution_model.create_completion(
                     system_message=SYSTEM_MESSAGE,
-                    prompt=filled_prompt,
-                    history=history,
+                    messages=messages,
+                    history=messages_demos,
                     # the models often repeat the instruction which could also contain </prompt> therefore we should not stop early
-                    stop=None,  # "</prompt>" if idx == len(DE_COT_PROMPTS) - 1 else None,
+                    stop=None,
                     use_randomness=True,
                 )
             )
+            evolutions_steps.append(
+                self.evolution_model._get_assistant_message(response)
+            )
             logger.debug(
                 "Performed evolution (step %d) using DE-CoT:\n\tInputs: %s\n\tResponse: %s",
                 idx,
@@ -504,22 +565,39 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
             # update history with recent turn
             history += recent_turn

-        # at this point we should get a new prompt
-        if "<prompt>" in response:
-            response = response.split("<prompt>")[1].split("</prompt>")[0]
+        evolved_prompt = self.parse_response(response)

-        logger.info(
-            "DE-CoT-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
-            prompt_1,
-            prompt_2,
-            best_prompt_current_evolution,
-            prompts_current_evolution[current_iteration],
-            response,
-        )
+        if evolved_prompt is None:
+            logger.info("Could not extract prompt from response: %s", response)
+        else:
+            logger.info(
+                "DE-CoT-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
+                prompt_1,
+                prompt_2,
+                best_prompt_current_evolution,
+                prompts_current_evolution[current_iteration],
+                evolved_prompt,
+            )
+
+        return evolved_prompt, judgements, usage
+
+    def condense_messages(self, messages: list[ChatMessages]) -> list[dict]:
+        if not messages:
+            return []

-        return response, judgements, usage
+        if messages[-1]["role"] == "assistant":
+            assistant_turn = messages[-1]
+            messages = messages[:-1]
+        else:
+            assistant_turn = None
+
+        user_turn = "\n\n".join(message["content"] for message in messages)
+        messages = [self.evolution_model._get_user_message(user_turn)]
+        if assistant_turn is not None:
+            messages.append(assistant_turn)
+        return messages

-    def get_prompt_template(self):
+    def _get_prompt_template(self):
         if self.use_evolution_demo:
             if isinstance(
                 self.task,
(TextClassification, Summarization, QuestionAnswering) @@ -532,9 +610,18 @@ class DifferentialEvolutionWithCot(DifferentialEvolution): f"Prompt with demonstration data is not implemented for task of type {type(self.task)}." ) - for prompt, demonstration_data_item in zip( + for prompt_template, demonstration_data_item in zip( DE_COT_PROMPTS, demonstration_data ): - yield get_demonstration_prompt_template(prompt, demonstration_data_item) + prompt_with_demonstration_data = self.build_demonstration_prompt( + [ + ( + prompt_template.format(**demonstration_data_item), + demonstration_data_item["response"], + ) + ] + ) + yield prompt_template, prompt_with_demonstration_data + # TODO how can we add a generation prefix for the model? else: yield from DE_COT_PROMPTS diff --git a/evoprompt/evolution/template.py b/evoprompt/evolution/template.py deleted file mode 100644 index a84dbb3..0000000 --- a/evoprompt/evolution/template.py +++ /dev/null @@ -1,6 +0,0 @@ -def get_demonstration_prompt_template(prompt_template: str, demonstration_data: dict): - prompt_template_with_demo = prompt_template.format(**demonstration_data) - prompt_template_with_demo += "\n\n" + demonstration_data["response"] - prompt_template_with_demo += "\n\n" + prompt_template - prompt_template_with_demo += "\n\n" + demonstration_data["generation_prefix"] - return prompt_template_with_demo -- GitLab
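
Note (illustrative, not part of the patch): the new parse_response helper centralizes the <prompt>-tag extraction that the GA, DE, and DE-CoT variants previously duplicated. Below is a minimal standalone sketch of the same extraction logic, assuming only the Python standard library; the sample responses are made up for illustration.

import re


def extract_prompt(response: str, allow_missing_end_token: bool = True) -> str | None:
    """Return the text between the last <prompt>...</prompt> pair, or None if absent."""
    # optionally tolerate a missing closing tag by also matching to the end of the response
    optional_tail = "|(.*)" if allow_missing_end_token else ""
    matches = re.findall(
        rf"<prompt>(?!.*<prompt>)(?:(.*)</prompt>{optional_tail})",
        response,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if matches and any(matches[0]):
        # only a single match is possible; exactly one of the two groups is non-empty
        return matches[0][0] or matches[0][1]
    return None


# made-up sample responses, for illustration only
print(extract_prompt("Reasoning ... <prompt>Summarize the text.</prompt>"))  # Summarize the text.
print(extract_prompt("Final answer: <prompt>Simplify the sentence."))        # Simplify the sentence.
print(extract_prompt("No tags in this response"))                            # None

Building the optional alternative in a separate variable (optional_tail) avoids reusing the same quote type inside the f-string; the inline construction in parse_response as committed relies on that nesting and therefore requires Python 3.12 or newer.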