Use correct format for demonstration samples for evaluation and evolution

Merged Max Kimmich requested to merge refactor-models into master
from collections.abc import Iterable
import logging
import re
from abc import ABCMeta, abstractmethod
@@ -7,7 +8,6 @@ from tqdm import trange
import wandb
import weave
from evoprompt.evolution.template import get_demonstration_prompt_template
from evoprompt.evolution.template_de import (
DE_DEMONSTRATION_DATA_CLS,
DE_DEMONSTRATION_DATA_SIM,
@@ -23,7 +23,7 @@ from evoprompt.evolution.template_ga import (
GA_DEMONSTRATION_DATA_SIM,
GA_PROMPT,
)
from evoprompt.models import LLMModel
from evoprompt.models import ChatMessages, LLMModel
from evoprompt.opt_types import ModelUsage, Prompt
from evoprompt.optimization import Judgement, PromptOptimization
from evoprompt.task import Task
@@ -50,14 +50,12 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
*,
task: Task,
evolution_model: LLMModel,
evaluation_model: LLMModel,
judge_model: LLMModel | None,
run_options: dict[str, Any] = {},
) -> None:
super().__init__(
task=task,
evolution_model=evolution_model,
evaluation_model=evaluation_model,
judge_model=judge_model,
run_options=run_options,
)
@@ -93,6 +91,43 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
) -> tuple[str, list[Judgement], ModelUsage]:
pass
def build_demonstration_prompt(
self,
demonstration_samples: Iterable[tuple[str, str]],
instruction: str | None = None,
) -> ChatMessages:
return self.evolution_model.build_demonstration_data(
demonstration_samples,
instruction=instruction,
)
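
Illustrative aside, not part of this diff: the ChatMessages returned here is assumed to be a list of role/content dicts with one user/assistant exchange per demonstration sample; the exact layout is up to LLMModel.build_demonstration_data in evoprompt.models. A minimal sketch under that assumption:

# Minimal sketch under the assumption that ChatMessages is a list of role/content dicts.
demonstration_samples = [("The plot was predictable but charming.", "positive")]
demo_messages = []
for model_input, target in demonstration_samples:
    demo_messages.append({"role": "user", "content": model_input})  # demonstration input
    demo_messages.append({"role": "assistant", "content": target})  # expected response
# such messages are later passed as `history` to create_completion
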
def parse_response(
self,
response: str,
start_token: str = "<prompt>",
end_token: str = "</prompt>",
allow_missing_end_token: bool = True,
):
# TODO another option would be to select the first match that is not equal to " and " (which is part of the instruction and usually repeated in the response)
matches = re.findall(
# regex that matches any characters between last pair of `start_token` and `end_token`, and optionally allow missing `end_token`
rf"{start_token}(?!.*{start_token})(?:(.*){end_token}{"|(.*)" if allow_missing_end_token else ""})",
response,
flags=(re.IGNORECASE | re.DOTALL),
)
if matches and any(matches[0]):
# there is always only a single match, and one group should be non-empty
if matches[0][0]:
evolved_prompt = matches[0][0]
else:
assert matches[0][1]
evolved_prompt = matches[0][1]
else:
# could not extract prompt -> no new prompt
evolved_prompt = None
return evolved_prompt
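
As a standalone reference (not part of the diff), the extraction behaviour of parse_response with its default arguments can be reproduced with the same regular expression that the old inline code used; the helper name below is hypothetical and the examples only illustrate the three typical cases (well-formed tags, missing end tag, no tags):

import re

def extract_last_prompt(response: str) -> str | None:
    # Mirrors parse_response with its defaults: take the text after the *last*
    # <prompt> tag, preferring a closed <prompt>...</prompt> pair and falling
    # back to everything after <prompt> if the closing tag is missing.
    matches = re.findall(
        r"<prompt>(?!.*<prompt>)(?:(.*)</prompt>|(.*))",
        response,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if matches and any(matches[0]):
        # only one of the two groups can be non-empty
        return matches[0][0] or matches[0][1]
    return None

extract_last_prompt("Sure: <prompt>Summarize the article in one sentence.</prompt>")
# -> "Summarize the article in one sentence."
extract_last_prompt("<prompt>draft</prompt> revised: <prompt>final version")
# -> "final version" (missing closing tag is tolerated)
extract_last_prompt("no tags in this response")
# -> None
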
@abstractmethod
def update(self, *args, **kwargs):
pass
@@ -163,8 +198,8 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
if p_i is not None:
prompt_source = (
"corrected" # could also mean that user skipped the prompt
if not all(j.happy for j in judgements)
else "generated"
if False in [j.happy for j in judgements]
else "evolution"
)
evolved_prompt = self.add_prompt(
p_i,
@@ -195,7 +230,13 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
# Line 8: Return the best prompt, p∗, among the final population P_T:
# p∗ ← argmax_{p ∈ P_T} f(p, D)
p = max(self.P[-1], key=lambda prompt: self.all_prompts[prompt.id].score)
logger.info("Best prompt with score %.2f: %s", p.score, p)
logger.info(
"Best prompt with score %.2f: %s (Source: %s - Gen: %d)",
p.score,
p,
p.meta["source"],
p.meta["gen"],
)
# We pick the prompt with the highest score on the development set and report its score on the test set.
test_performance, _, _ = self.task.evaluate_test(p.content)
@@ -241,20 +282,23 @@ class GeneticAlgorithm(EvolutionAlgorithm):
# Based on this two-step process, we design instructions, guiding LLMs to
# generate a new prompt based on these steps to perform Evo(·) in Algorithm 1.
filled_prompt = self.get_prompt_template().format(
prompt = self._get_prompt_template()
if isinstance(prompt, tuple) and len(prompt) > 1:
# extract demonstrations
prompt, demo_messages = prompt
filled_prompt = prompt.format(
prompt1=prompt_1,
prompt2=prompt_2,
)
evolved_prompt, history, recent_turn, usage = (
self.evolution_model.create_completion(
system_message=SYSTEM_MESSAGE,
prompt=filled_prompt,
use_randomness=True,
)
response, _, recent_turn, usage = self.evolution_model.create_completion(
system_message=SYSTEM_MESSAGE,
prompt=filled_prompt,
history=demo_messages if self.use_evolution_demo else None,
use_randomness=True,
)
judgement = self.judge_and_correct_step(
filled_prompt, evolved_prompt, history, recent_turn
filled_prompt, response, history=None, recent_turn=recent_turn
)
if judgement.skip:
@@ -265,17 +309,18 @@ class GeneticAlgorithm(EvolutionAlgorithm):
usage,
)
evolved_prompt = judgement.corrected_response
if "<prompt>" in evolved_prompt:
evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
response = judgement.corrected_response
evolved_prompt = self.parse_response(response)
logger.info(
"GA-evolved prompts '%s' and '%s' into '%s'",
prompt_1,
prompt_2,
evolved_prompt,
)
if evolved_prompt is None:
logger.info(f"Could not extract prompt from response: {response}")
else:
logger.info(
"GA-evolved prompts '%s' and '%s' into '%s'",
prompt_1,
prompt_2,
evolved_prompt,
)
return evolved_prompt, [judgement], usage
@@ -301,22 +346,28 @@ class GeneticAlgorithm(EvolutionAlgorithm):
return retained_prompts
def get_prompt_template(self):
def _get_prompt_template(self):
if self.use_evolution_demo:
if isinstance(
self.task, (TextClassification, Summarization, QuestionAnswering)
):
return get_demonstration_prompt_template(
GA_PROMPT, GA_DEMONSTRATION_DATA_SIM
)
demonstration_data = GA_DEMONSTRATION_DATA_SIM
elif isinstance(self.task, Simplification):
return get_demonstration_prompt_template(
GA_PROMPT, GA_DEMONSTRATION_DATA_CLS
)
demonstration_data = GA_DEMONSTRATION_DATA_CLS
else:
raise NotImplementedError(
f"Prompt with demonstration data is not implemented for task of type {type(self.task)}."
)
prompt_with_demonstration_data = self.build_demonstration_prompt(
[
(
GA_PROMPT.format(**demonstration_data),
demonstration_data["response"],
)
]
)
return GA_PROMPT, prompt_with_demonstration_data
return GA_PROMPT
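
Side note, not part of the change: after this refactoring `_get_prompt_template` returns either the bare template string or a `(template, demonstration_messages)` tuple, which is why `evolve` branches on the type. A hypothetical helper, just to spell out that contract:

# Hypothetical helper (not part of the MR) spelling out the new return contract
# of _get_prompt_template as consumed by the GA and DE evolve() implementations.
def unpack_template(template_or_pair):
    if isinstance(template_or_pair, tuple) and len(template_or_pair) > 1:
        template, demo_messages = template_or_pair  # demonstrations enabled
    else:
        template, demo_messages = template_or_pair, None
    return template, demo_messages
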
@@ -342,59 +393,56 @@ class DifferentialEvolution(EvolutionAlgorithm):
prompts_current_evolution, key=lambda prompt: prompt.score
)
filled_prompt = self.get_prompt_template().format(
prompt = self._get_prompt_template()
if isinstance(prompt, tuple) and len(prompt) > 1:
# extract demonstrations
prompt, demo_messages = prompt
filled_prompt = prompt.format(
prompt1=prompt_1,
prompt2=prompt_2,
prompt3=best_prompt_current_evolution,
basic_prompt=prompts_current_evolution[current_iteration],
)
evolved_prompt, history, recent_turn, usage = (
self.evolution_model.create_completion(
system_message=SYSTEM_MESSAGE,
prompt=filled_prompt,
use_randomness=True,
)
response, _, recent_turn, usage = self.evolution_model.create_completion(
system_message=SYSTEM_MESSAGE,
prompt=filled_prompt,
history=demo_messages if self.use_evolution_demo else None,
use_randomness=True,
)
judgement = self.judge_and_correct_step(
filled_prompt, evolved_prompt, history, recent_turn
filled_prompt, response, history=None, recent_turn=recent_turn
)
if judgement.skip:
# skip this prompt, for DE this means using the basic prompt
# user asked to skip this prompt, for DE this means using the basic prompt
return (
prompts_current_evolution[current_iteration].content,
[judgement],
usage,
)
evolved_prompt = judgement.corrected_response
response = judgement.corrected_response
evolved_prompt = self.parse_response(response)
matches = re.findall(
# regex that matches any characters between last pair of <prompt></prompt>, also if </prompt> is missing
r"<prompt>(?!.*<prompt>)(?:(.*)</prompt>|(.*))",
evolved_prompt,
flags=(re.IGNORECASE | re.DOTALL),
)
if matches and any(matches[0]):
# there is always only a single match, and one group should be non-empty
if matches[0][0]:
evolved_prompt = matches[0][0]
else:
assert matches[0][1]
evolved_prompt = matches[0][1]
if evolved_prompt is None:
logger.info(f"Could not extract prompt from response: {response}")
# no prompt was returned (e.g., evolved prompt could not be extracted), therefore, for DE, we use the basic prompt
return (
prompts_current_evolution[current_iteration].content,
[judgement],
usage,
)
else:
# TODO what to do in this case? Discard generated prompt directly?
pass
logger.info(
"DE-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
prompt_1,
prompt_2,
best_prompt_current_evolution,
prompts_current_evolution[current_iteration],
evolved_prompt,
)
logger.info(
"DE-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
prompt_1,
prompt_2,
best_prompt_current_evolution,
prompts_current_evolution[current_iteration],
evolved_prompt,
)
return evolved_prompt, [judgement], usage
@@ -412,22 +460,28 @@ class DifferentialEvolution(EvolutionAlgorithm):
]
return population
def get_prompt_template(self):
def _get_prompt_template(self):
if self.use_evolution_demo:
if isinstance(
self.task, (TextClassification, Summarization, QuestionAnswering)
):
return get_demonstration_prompt_template(
DE_PROMPT, DE_DEMONSTRATION_DATA_SIM
)
demonstration_data = DE_DEMONSTRATION_DATA_SIM
elif isinstance(self.task, Simplification):
return get_demonstration_prompt_template(
DE_PROMPT, DE_DEMONSTRATION_DATA_CLS
)
demonstration_data = DE_DEMONSTRATION_DATA_CLS
else:
raise NotImplementedError(
f"Prompt with demonstration data is not implemented for task of type {type(self.task)}."
)
prompt_with_demonstration_data = self.build_demonstration_prompt(
[
(
DE_PROMPT.format(**demonstration_data),
demonstration_data["response"],
)
]
)
return DE_PROMPT, prompt_with_demonstration_data
return DE_PROMPT
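
For orientation (illustrative values only, not part of the diff): the DE template is filled with four prompts per evolution step. The placeholder names below are taken from the `format` call in `evolve` above; the example strings are made up.

# Example values are invented; only the placeholder names come from evolve().
filled_prompt = DE_PROMPT.format(
    prompt1="Classify the sentiment of the review.",                  # first parent prompt
    prompt2="Decide whether the review is positive or negative.",     # second parent prompt
    prompt3="Label the review's sentiment as positive or negative.",  # best prompt of the current generation
    basic_prompt="Determine the sentiment of the text.",              # prompt being evolved/replaced
)
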
@@ -453,40 +507,64 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
prompts_current_evolution, key=lambda prompt: prompt.score
)
history = None
response: str = ""
# list of evolution steps
evolutions_steps = []
# list (turns) of list (demonstrations)
demos = [[]]
judgements: list[Judgement] = []
usage: ModelUsage = ModelUsage()
for idx, prompt in enumerate(self.get_prompt_template()):
for idx, prompt in enumerate(self._get_prompt_template()):
if isinstance(prompt, tuple) and len(prompt) > 1:
# extract demonstrations
prompt, demo_messages = prompt
demos[-1].extend(demo_messages)
if self.use_evolution_demo:
messages_demos = self.condense_messages(demos[-1])
else:
messages_demos = []
filled_prompt = prompt.format(
prompt1=prompt_1,
prompt2=prompt_2,
prompt3=best_prompt_current_evolution,
basic_prompt=prompts_current_evolution[current_iteration],
)
evolutions_steps.append(
self.evolution_model._get_user_message(filled_prompt)
)
# TODO Shall we still use only a single turn containing all messages if we do not use demonstrations for evolution?
prompt = self.condense_messages(evolutions_steps, return_str=True)
response, history, recent_turn, usage = (
self.evolution_model.create_completion(
system_message=SYSTEM_MESSAGE,
prompt=filled_prompt,
history=history,
prompt=prompt,
history=messages_demos,
# the models often repeat the instruction, which could also contain </prompt>, therefore we should not stop early
stop=None, # "</prompt>" if idx == len(DE_COT_PROMPTS) - 1 else None,
stop=None,
use_randomness=True,
)
)
evolutions_steps.append(
self.evolution_model._get_assistant_message(response)
)
logger.debug(
"Performed evolution (step %d) using DE-CoT:\n\tInputs: %s\n\tResponse: %s",
idx,
history + recent_turn,
response,
)
# TODO use serialized messages as prompt or use previous evolution steps as history?
judgement = self.judge_and_correct_step(
filled_prompt, response, history=history, recent_turn=recent_turn
filled_prompt,
response,
history=evolutions_steps[:-2],
recent_turn=recent_turn,
# prompt, response, history=None, recent_turn=recent_turn
)
judgements.append(judgement)
if judgement.skip:
# skip this prompt, for DE this means using the basic prompt
# user asked to skip this prompt, for DE this means using the basic prompt
return (
prompts_current_evolution[current_iteration].content,
judgements,
@@ -494,28 +572,62 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
)
# replace last message with corrected response
recent_turn[-1]["content"] = judgement.corrected_response
response = judgement.corrected_response
# update history with recent turn
history += recent_turn
history.append(self.evolution_model._get_assistant_message(response))
# at this point we should get a new prompt
if "<prompt>" in response:
response = response.split("<prompt>")[1].split("</prompt>")[0]
evolved_prompt = self.parse_response(response)
logger.info(
"DE-CoT-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
prompt_1,
prompt_2,
best_prompt_current_evolution,
prompts_current_evolution[current_iteration],
response,
)
return response, judgements, usage
if evolved_prompt is None:
logger.info(f"Could not extract prompt from response: {response}")
# no prompt was returned (e.g., evolved prompt could not be extracted), therefore, for DE, we use the basic prompt
return (
prompts_current_evolution[current_iteration].content,
judgements,
usage,
)
else:
logger.info(
"DE-CoT-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
prompt_1,
prompt_2,
best_prompt_current_evolution,
prompts_current_evolution[current_iteration],
evolved_prompt,
)
return evolved_prompt, judgements, usage
def condense_messages(
self, messages: list[ChatMessages], return_str: bool = False
) -> list[dict] | str:
if not messages:
if return_str:
return ""
return []
if messages[-1]["role"] == "assistant":
assistant_turn = messages[-1]
messages = messages[:-1]
else:
assistant_turn = None
user_turn = "\n\n".join(message["content"] for message in messages)
if return_str:
assert (
assistant_turn is None
), "Cannot return string if most recent turn is from assistant."
return user_turn
messages = [self.evolution_model._get_user_message(user_turn)]
if assistant_turn is not None:
messages.append(assistant_turn)
return messages
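
Illustrative aside, not part of the diff: `condense_messages` joins consecutive user messages with blank lines and, unless a string is requested, keeps a trailing assistant message as its own turn. Assuming the model helpers produce plain role/content dicts, the expected behaviour is roughly:

# Illustrative only; the role/content dict layout is an assumption.
steps = [
    {"role": "user", "content": "Step 1: identify the differences between the prompts."},
    {"role": "user", "content": "Step 2: mutate the differing part."},
    {"role": "assistant", "content": "<prompt>improved instruction</prompt>"},
]
# condense_messages(steps) ->
#   [{"role": "user", "content": "Step 1: ...\n\nStep 2: ..."},
#    {"role": "assistant", "content": "<prompt>improved instruction</prompt>"}]
# condense_messages(steps[:-1], return_str=True) ->
#   "Step 1: identify the differences between the prompts.\n\nStep 2: mutate the differing part."
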
def get_prompt_template(self):
def _get_prompt_template(self):
if self.use_evolution_demo:
if isinstance(
self.task, (TextClassification, Summarization, QuestionAnswering)
@@ -528,9 +640,18 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
f"Prompt with demonstration data is not implemented for task of type {type(self.task)}."
)
for prompt, demonstration_data_item in zip(
for prompt_template, demonstration_data_item in zip(
DE_COT_PROMPTS, demonstration_data
):
yield get_demonstration_prompt_template(prompt, demonstration_data_item)
prompt_with_demonstration_data = self.build_demonstration_prompt(
[
(
prompt_template.format(**demonstration_data_item),
demonstration_data_item["response"],
)
]
)
yield prompt_template, prompt_with_demonstration_data
# TODO how can we add a generation prefix for the model?
else:
yield from DE_COT_PROMPTS