Commit f57cbe07 authored by Max Kimmich

Use chat format for evolution demonstration samples

parent 975b2fd8
Merge request !8: Use correct format for demonstration samples for evaluation and evolution
+from collections.abc import Iterable
import logging
import re
from abc import ABCMeta, abstractmethod
@@ -7,7 +8,6 @@ from tqdm import trange
import wandb
import weave
-from evoprompt.evolution.template import get_demonstration_prompt_template
from evoprompt.evolution.template_de import (
DE_DEMONSTRATION_DATA_CLS,
DE_DEMONSTRATION_DATA_SIM,
@@ -91,6 +91,42 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
) -> tuple[str, list[Judgement], ModelUsage]:
pass
+def build_demonstration_prompt(
+    self,
+    demonstration_samples: Iterable[tuple[str, str]],
+    instruction: str | None = None,
+) -> ChatMessages:
+    return self.evolution_model.build_demonstration_data(
+        demonstration_samples,
+        instruction=instruction,
+    )
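Note: `build_demonstration_data` lives on the evolution model and is not shown in this diff. A minimal sketch of the chat-format output it is assumed to produce (one user/assistant pair per sample; names and shape are illustrative, not the actual method):

    # Hypothetical sketch, not the real model method.
    def build_demonstration_data(samples, instruction=None):
        messages = []
        for input_text, target in samples:
            content = f"{instruction}\n\n{input_text}" if instruction else input_text
            messages.append({"role": "user", "content": content})
            messages.append({"role": "assistant", "content": target})
        return messages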
+def parse_response(
+    self,
+    response: str,
+    start_token: str = "<prompt>",
+    end_token: str = "</prompt>",
+    allow_missing_end_token: bool = True,
+):
+    matches = re.findall(
+        # regex that matches any characters between the last pair of `start_token` and `end_token`, optionally allowing a missing `end_token`
+        rf"{start_token}(?!.*{start_token})(?:(.*){end_token}{'|(.*)' if allow_missing_end_token else ''})",
+        response,
+        flags=(re.IGNORECASE | re.DOTALL),
+    )
+    if matches and any(matches[0]):
+        # there is always only a single match, and one group should be non-empty
+        if matches[0][0]:
+            evolved_prompt = matches[0][0]
+        else:
+            assert matches[0][1]
+            evolved_prompt = matches[0][1]
+    else:
+        # could not extract prompt -> no new prompt
+        evolved_prompt = None
+    return evolved_prompt
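The extraction logic can be exercised standalone; a small self-contained check of the same pattern with the defaults used in this diff (note the single quotes inside the f-string expression, so it also compiles on Python < 3.12):

    import re

    def parse_response(response, start_token="<prompt>", end_token="</prompt>", allow_missing_end_token=True):
        pattern = rf"{start_token}(?!.*{start_token})(?:(.*){end_token}{'|(.*)' if allow_missing_end_token else ''})"
        matches = re.findall(pattern, response, flags=re.IGNORECASE | re.DOTALL)
        if matches and any(matches[0]):
            return matches[0][0] or matches[0][1]
        return None

    assert parse_response("ok <prompt>Summarize the text.</prompt> done") == "Summarize the text."
    assert parse_response("<prompt>old</prompt> <prompt>new, unterminated") == "new, unterminated"
    assert parse_response("no tags at all") is None

The negative lookahead anchors on the last `start_token`, so only the final tagged span is extracted even when the model echoes earlier prompts.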
@abstractmethod
def update(self, *args, **kwargs):
pass
@@ -245,20 +281,23 @@ class GeneticAlgorithm(EvolutionAlgorithm):
# Based on this two-step process, we design instructions, guiding LLMs to
# generate a new prompt based on these steps to perform Evo(·) in Algorithm 1.
-filled_prompt = self.get_prompt_template().format(
+prompt = self._get_prompt_template()
+if isinstance(prompt, tuple) and len(prompt) > 1:
+    # extract demonstrations
+    prompt, demo_messages = prompt
+filled_prompt = prompt.format(
prompt1=prompt_1,
prompt2=prompt_2,
)
-evolved_prompt, history, recent_turn, usage = (
-    self.evolution_model.create_completion(
-        system_message=SYSTEM_MESSAGE,
-        prompt=filled_prompt,
-        use_randomness=True,
-    )
-)
+response, history, recent_turn, usage = self.evolution_model.create_completion(
+    system_message=SYSTEM_MESSAGE,
+    prompt=filled_prompt,
+    history=demo_messages if self.use_evolution_demo else None,
+    use_randomness=True,
+)
judgement = self.judge_and_correct_step(
-    filled_prompt, evolved_prompt, history, recent_turn
+    filled_prompt, response, history, recent_turn
)
if judgement.skip:
@@ -269,17 +308,18 @@ class GeneticAlgorithm(EvolutionAlgorithm):
usage,
)
-evolved_prompt = judgement.corrected_response
-if "<prompt>" in evolved_prompt:
-    evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
+response = judgement.corrected_response
+evolved_prompt = self.parse_response(response)
-logger.info(
-    "GA-evolved prompts '%s' and '%s' into '%s'",
-    prompt_1,
-    prompt_2,
-    evolved_prompt,
-)
+if evolved_prompt is None:
+    logger.info("Could not extract prompt from response: %s", response)
+else:
+    logger.info(
+        "GA-evolved prompts '%s' and '%s' into '%s'",
+        prompt_1,
+        prompt_2,
+        evolved_prompt,
+    )
return evolved_prompt, [judgement], usage
@@ -305,22 +345,28 @@ class GeneticAlgorithm(EvolutionAlgorithm):
return retained_prompts
-def get_prompt_template(self):
+def _get_prompt_template(self):
if self.use_evolution_demo:
if isinstance(
self.task, (TextClassification, Summarization, QuestionAnswering)
):
-return get_demonstration_prompt_template(
-    GA_PROMPT, GA_DEMONSTRATION_DATA_SIM
-)
+demonstration_data = GA_DEMONSTRATION_DATA_SIM
elif isinstance(self.task, Simplification):
-return get_demonstration_prompt_template(
-    GA_PROMPT, GA_DEMONSTRATION_DATA_CLS
-)
+demonstration_data = GA_DEMONSTRATION_DATA_CLS
else:
raise NotImplementedError(
f"Prompt with demonstration data is not implemented for task of type {type(self.task)}."
)
+prompt_with_demonstration_data = self.build_demonstration_prompt(
+    [
+        (
+            GA_PROMPT.format(**demonstration_data),
+            demonstration_data["response"],
+        )
+    ]
+)
+return GA_PROMPT, prompt_with_demonstration_data
return GA_PROMPT
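With demonstrations enabled, `_get_prompt_template` now returns a `(template, demo_messages)` pair instead of a pre-rendered string, and the caller passes the demonstrations as chat history. Roughly, under the message shape assumed here:

    template, demo_messages = self._get_prompt_template()
    # demo_messages ≈ [
    #     {"role": "user", "content": GA_PROMPT filled with the demo samples},
    #     {"role": "assistant", "content": the expected demo response},
    # ]
    filled = template.format(prompt1=prompt_1, prompt2=prompt_2)  # the real request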
@@ -346,22 +392,25 @@ class DifferentialEvolution(EvolutionAlgorithm):
prompts_current_evolution, key=lambda prompt: prompt.score
)
-filled_prompt = self.get_prompt_template().format(
+prompt = self._get_prompt_template()
+if isinstance(prompt, tuple) and len(prompt) > 1:
+    # extract demonstrations
+    prompt, demo_messages = prompt
+filled_prompt = prompt.format(
prompt1=prompt_1,
prompt2=prompt_2,
prompt3=best_prompt_current_evolution,
basic_prompt=prompts_current_evolution[current_iteration],
)
-evolved_prompt, history, recent_turn, usage = (
-    self.evolution_model.create_completion(
-        system_message=SYSTEM_MESSAGE,
-        prompt=filled_prompt,
-        use_randomness=True,
-    )
-)
+response, history, recent_turn, usage = self.evolution_model.create_completion(
+    system_message=SYSTEM_MESSAGE,
+    prompt=filled_prompt,
+    history=demo_messages if self.use_evolution_demo else None,
+    use_randomness=True,
+)
judgement = self.judge_and_correct_step(
-    filled_prompt, evolved_prompt, history, recent_turn
+    filled_prompt, response, history, recent_turn
)
if judgement.skip:
@@ -372,33 +421,20 @@ class DifferentialEvolution(EvolutionAlgorithm):
usage,
)
-evolved_prompt = judgement.corrected_response
-matches = re.findall(
-    # regex that matches any characters between last pair of <prompt></prompt>, also if </prompt> is missing
-    r"<prompt>(?!.*<prompt>)(?:(.*)</prompt>|(.*))",
-    evolved_prompt,
-    flags=(re.IGNORECASE | re.DOTALL),
-)
-if matches and any(matches[0]):
-    # there is always only a single match, and one group should be non-empty
-    if matches[0][0]:
-        evolved_prompt = matches[0][0]
-    else:
-        assert matches[0][1]
-        evolved_prompt = matches[0][1]
-else:
-    # TODO what to do in this case? Discard generated prompt directly?
-    pass
-logger.info(
-    "DE-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
-    prompt_1,
-    prompt_2,
-    best_prompt_current_evolution,
-    prompts_current_evolution[current_iteration],
-    evolved_prompt,
-)
+response = judgement.corrected_response
+evolved_prompt = self.parse_response(response)
+if evolved_prompt is None:
+    logger.info("Could not extract prompt from response: %s", response)
+else:
+    logger.info(
+        "DE-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
+        prompt_1,
+        prompt_2,
+        best_prompt_current_evolution,
+        prompts_current_evolution[current_iteration],
+        evolved_prompt,
+    )
return evolved_prompt, [judgement], usage
@@ -416,22 +452,28 @@ class DifferentialEvolution(EvolutionAlgorithm):
]
return population
-def get_prompt_template(self):
+def _get_prompt_template(self):
if self.use_evolution_demo:
if isinstance(
self.task, (TextClassification, Summarization, QuestionAnswering)
):
-return get_demonstration_prompt_template(
-    DE_PROMPT, DE_DEMONSTRATION_DATA_SIM
-)
+demonstration_data = DE_DEMONSTRATION_DATA_SIM
elif isinstance(self.task, Simplification):
-return get_demonstration_prompt_template(
-    DE_PROMPT, DE_DEMONSTRATION_DATA_CLS
-)
+demonstration_data = DE_DEMONSTRATION_DATA_CLS
else:
raise NotImplementedError(
f"Prompt with demonstration data is not implemented for task of type {type(self.task)}."
)
+prompt_with_demonstration_data = self.build_demonstration_prompt(
+    [
+        (
+            DE_PROMPT.format(**demonstration_data),
+            demonstration_data["response"],
+        )
+    ]
+)
+return DE_PROMPT, prompt_with_demonstration_data
return DE_PROMPT
@@ -457,27 +499,46 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
prompts_current_evolution, key=lambda prompt: prompt.score
)
-history = None
+# list of evolution steps
+evolutions_steps = []
+# list (turns) of list (demonstrations)
+demos = [[]]
response: str = ""
judgements: list[Judgement] = []
usage: ModelUsage = ModelUsage()
-for idx, prompt in enumerate(self.get_prompt_template()):
+for idx, prompt in enumerate(self._get_prompt_template()):
+    if isinstance(prompt, tuple) and len(prompt) > 1:
+        # extract demonstrations
+        prompt, demo_messages = prompt
+        demos[-1].extend(demo_messages)
+    if self.use_evolution_demo:
+        messages_demos = self.condense_messages(demos[-1])
+    else:
+        messages_demos = []
filled_prompt = prompt.format(
prompt1=prompt_1,
prompt2=prompt_2,
prompt3=best_prompt_current_evolution,
basic_prompt=prompts_current_evolution[current_iteration],
)
+evolutions_steps.append(
+    self.evolution_model._get_user_message(filled_prompt)
+)
+# TODO Shall we still use only a single turn containing all messages if we do not use demonstrations for evolution?
+messages = self.condense_messages(evolutions_steps)
response, history, recent_turn, usage = (
    self.evolution_model.create_completion(
        system_message=SYSTEM_MESSAGE,
-        prompt=filled_prompt,
-        history=history,
-        stop=None,  # "</prompt>" if idx == len(DE_COT_PROMPTS) - 1 else None,
+        messages=messages,
+        history=messages_demos,
+        # the models often repeat the instruction, which could also contain </prompt>, therefore we should not stop early
+        stop=None,
        use_randomness=True,
    )
)
+evolutions_steps.append(
+    self.evolution_model._get_assistant_message(response)
+)
logger.debug(
"Performed evolution (step %d) using DE-CoT:\n\tInputs: %s\n\tResponse: %s",
idx,
@@ -504,22 +565,39 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
# update history with recent turn
history += recent_turn
# at this point we should get a new prompt
if "<prompt>" in response:
response = response.split("<prompt>")[1].split("</prompt>")[0]
evolved_prompt = self.parse_response(response)
logger.info(
"DE-CoT-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
prompt_1,
prompt_2,
best_prompt_current_evolution,
prompts_current_evolution[current_iteration],
response,
)
if evolved_prompt is None:
logger.info(f"Could not extract prompt from response: {evolved_prompt}")
else:
logger.info(
"DE-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
prompt_1,
prompt_2,
best_prompt_current_evolution,
prompts_current_evolution[current_iteration],
evolved_prompt,
)
return evolved_prompt, judgements, usage
def condense_messages(self, messages: list[ChatMessages]) -> list[dict]:
if not messages:
return []
return response, judgements, usage
if messages[-1]["role"] == "assistant":
assistant_turn = messages[-1]
messages = messages[:-1]
else:
assistant_turn = None
user_turn = "\n\n".join(message["content"] for message in messages)
messages = [self.evolution_model._get_user_message(user_turn)]
if assistant_turn is not None:
messages.append(assistant_turn)
return messages
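For intuition, what `condense_messages` does to a short history (message dicts assumed to be plain role/content pairs):

    history = [
        {"role": "user", "content": "Step 1: identify differences."},
        {"role": "user", "content": "Step 2: mutate them."},
        {"role": "assistant", "content": "<prompt>Merged instruction</prompt>"},
    ]
    # condensed result: all user turns merged into one, trailing assistant turn kept
    # [
    #     {"role": "user", "content": "Step 1: identify differences.\n\nStep 2: mutate them."},
    #     {"role": "assistant", "content": "<prompt>Merged instruction</prompt>"},
    # ]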
-def get_prompt_template(self):
+def _get_prompt_template(self):
if self.use_evolution_demo:
if isinstance(
self.task, (TextClassification, Summarization, QuestionAnswering)
@@ -532,9 +610,18 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
f"Prompt with demonstration data is not implemented for task of type {type(self.task)}."
)
-for prompt, demonstration_data_item in zip(
+for prompt_template, demonstration_data_item in zip(
    DE_COT_PROMPTS, demonstration_data
):
-    yield get_demonstration_prompt_template(prompt, demonstration_data_item)
+    prompt_with_demonstration_data = self.build_demonstration_prompt(
+        [
+            (
+                prompt_template.format(**demonstration_data_item),
+                demonstration_data_item["response"],
+            )
+        ]
+    )
+    yield prompt_template, prompt_with_demonstration_data
+    # TODO how can we add a generation prefix for the model?
else:
yield from DE_COT_PROMPTS
-def get_demonstration_prompt_template(prompt_template: str, demonstration_data: dict):
-    prompt_template_with_demo = prompt_template.format(**demonstration_data)
-    prompt_template_with_demo += "\n\n" + demonstration_data["response"]
-    prompt_template_with_demo += "\n\n" + prompt_template
-    prompt_template_with_demo += "\n\n" + demonstration_data["generation_prefix"]
-    return prompt_template_with_demo
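The removed helper above inlined demonstrations into a single flat string; the commit replaces that with proper chat turns. A rough before/after, with template and data invented for illustration:

    template = "Combine the prompts {prompt1} and {prompt2}."
    demo = {"prompt1": "Summarize:", "prompt2": "TL;DR:", "response": "<prompt>Condense the text.</prompt>"}

    # before: demo rendered and concatenated into the prompt string itself
    flat = template.format(**demo) + "\n\n" + demo["response"] + "\n\n" + template

    # after: the same demo becomes a user/assistant exchange passed as history
    history = [
        {"role": "user", "content": template.format(**demo)},
        {"role": "assistant", "content": demo["response"]},
    ]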