From 31c10c46c9bc286abf999ee092a2d5763c18945c Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Tue, 3 Sep 2024 17:11:41 +0200
Subject: [PATCH] Fix topic classification and question answering tasks

---
 evoprompt/evolution.py                 | 412 -----------------------------
 evoprompt/evolution/evolution.py       |   4 +-
 evoprompt/task/topic_classification.py |  10 +-
 3 files changed, 11 insertions(+), 415 deletions(-)
 delete mode 100644 evoprompt/evolution.py

diff --git a/evoprompt/evolution.py b/evoprompt/evolution.py
deleted file mode 100644
index 0bf36db..0000000
--- a/evoprompt/evolution.py
+++ /dev/null
@@ -1,412 +0,0 @@
-import logging
-import re
-from abc import ABCMeta, abstractmethod
-from typing import Any
-
-from tqdm import trange
-
-from evoprompt.cli import argument_parser
-from evoprompt.evolution.template_de import get_de_prompt_template
-from evoprompt.models import LLMModel
-from evoprompt.opt_types import ModelUsage, Prompt
-from evoprompt.optimization import Judgement, PromptOptimization
-from evoprompt.task import Task
-from evoprompt.utils import get_all_subclasses, get_rng, log_calls
-
-logger = logging.getLogger(__name__)
-
-
-SYSTEM_MESSAGE = "Please carefully follow the instruction step-by-step."
-
-GA_PROMPT = """
-1. Cross over the following prompts and generate a new prompt:
-Prompt 1: {prompt1}
-Prompt 2: {prompt2}
-2. Mutate the prompt generated in Step 1 and generate a final prompt bracketed with <prompt> and </prompt>.
-"""
-
-
-DE_COT_PROMPTS = [
-    "Step 1: Identify the main different parts between the Prompt 1 and Prompt 2:\nPrompt 1: {prompt1}\nPrompt 2: {prompt2}",
-    "Step 2: Randomly mutate the different parts",
-    "Step 3: Combine the different parts with Prompt 3, selectively replace it with the different parts in Step 2 and generate a new prompt.\nPrompt 3: {prompt3}",
-    "Step 4: Cross over the prompt in the Step 3 with the following basic prompt and generate a final prompt bracketed with <prompt> and </prompt>:\nBasic Prompt: {basic_prompt}",
-]
-
-
-class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
-    shorthand: str
-
-    """The super class for all evolution algorithms containing shared parameters."""
-
-    def __init__(
-        self,
-        population_size: int,
-        *,
-        task: Task,
-        evolution_model: LLMModel,
-        evaluation_model: LLMModel,
-        judge_model: LLMModel | None,
-        run_options: dict[str, Any] = {},
-    ) -> None:
-        super().__init__(
-            task=task,
-            evolution_model=evolution_model,
-            evaluation_model=evaluation_model,
-            judge_model=judge_model,
-            run_options=run_options,
-        )
-        self.use_evolution_demo = run_options.get("use_evolution_demo", False)
-
-        self.population_size = population_size
-
-    @log_calls("Performing selection")
-    def select(self, prompts: list[Prompt]):
-        # In GA, two parent solutions are normally selected based on the roulette wheel
-        # selection method according to the fitness value (Lipowski & Lipowska, 2012).
-        # Similar to this, we utilize the roulette wheel selection method to select
-        # two parent prompts in the current population according to the scores evaluated
-        # on development sets. Specifically, let si denote the performance score on the
-        # development set of the i-th prompt in the population, which contains a total
-        # of N prompts. The probability of selecting the i-th prompt as a parent can be expressed as
-        # pi = si / Σj=1->N sj.
-        # add small value to avoid zero chance of selection for some prompts
-        scores = [prompt.score + 1e-6 for prompt in prompts]
-        selection_probabilities = [score / sum(scores) for score in scores]
-        return get_rng().choice(
-            prompts, size=2, replace=False, p=selection_probabilities
-        )
-
-    @abstractmethod
-    def evolve(
-        self,
-        prompt_1: str,
-        prompt_2: str,
-        *,
-        prompts_current_evolution: list[Prompt],
-        current_iteration: int,
-    ) -> tuple[str, list[Judgement], ModelUsage]:
-        pass
-
-    @abstractmethod
-    def update(self, *args, **kwargs):
-        pass
-
-    def run(self, num_iterations: int, debug: bool = False) -> None:
-        # debug mode for quick run
-        if debug:
-            self.population_size = 3
-            num_iterations = 2
-
-        self.init_run(self.population_size, num_iterations, debug=debug)
-
-        # Algorithm 1 Discrete prompt optimization: EVOPROMPT
-
-        # Line 2:
-        for t in self.iterations_pbar:
-            # Line 3: Selection: select a certain number of prompts from current population as parent prompts
-            # pr1,...,prk ∼ Pt−1
-            prompts_current_evolution = self.P[t - 1]
-
-            new_evolutions = []
-
-            for i in trange(self.population_size, desc="updates", leave=False):
-                # for both GA and DE we start with two parent prompts
-                pr1, pr2 = self.select(self.P[t - 1])
-
-                # Line 4: Evolution: generate a new prompt based on the selected parent prompts by leveraging LLM to perform evolutionary operators
-                # p′i ← Evo(pr1,...,prk)
-                (
-                    p_i,
-                    judgements,
-                    evolution_usage,
-                ) = self.evolve(
-                    pr1,
-                    pr2,
-                    prompts_current_evolution=prompts_current_evolution,
-                    current_iteration=i,
-                )
-                self.total_evolution_usage += evolution_usage
-
-                prompt_source = (
-                    "corrected" if not all(j.happy for j in judgements) else "generated"
-                )
-                evolved_prompt = self.add_prompt(
-                    p_i,
-                    parents=(pr1, pr2),
-                    meta={"gen": t, "source": prompt_source, "judgements": judgements},
-                )
-                self.total_evaluation_usage += evolved_prompt.usage
-
-                new_evolutions.append(evolved_prompt)
-                self.save_snapshot()
-            # Line 6: Update based on the evaluation scores
-            # Pt ← {Pt−1, p′i} and St ← {St−1, s′i}
-            new_population = self.update(new_evolutions, prompts_current_evolution)
-
-            # store new generation
-            self.P.append(new_population)
-            self.save_snapshot()
-
-        self.save_snapshot()
-        # Line 8: Return the best prompt, p∗, among the final population PT :
-        # p∗ ← argmaxp∈PT f(p, D)
-        p = max(self.P[-1], key=lambda prompt: self.all_prompts[prompt.id].score)
-        logger.info("Best prompt with score %.2f: %s", p.score, p)
-
-        # We pick the prompt with the highest score on the development set and report its score on the testset.
-        test_performance, _, _ = self.task.evaluate_test(p.content)
-        logger.info(
-            "Best prompt on test set: %s %s", test_performance, self.task.metric_name
-        )
-        logger.info(
-            "Usage (evolution model / evaluation model / total): %s / %s / %s",
-            self.total_evolution_usage,
-            self.total_evaluation_usage,
-            self.total_evolution_usage + self.total_evaluation_usage,
-        )
-
-        return self.total_evolution_usage, self.total_evaluation_usage
-
-
-class GeneticAlgorithm(EvolutionAlgorithm):
-    """The genetic algorithm (GA) implemented using LLMs."""
-
-    shorthand = "ga"
-
-    # kwargs is just there for convenience, as evolve function of other optimizers might have different inputs
-    # @register_action(ignore_args=["kwargs"])
-    @log_calls("Performing prompt evolution using GA")
-    def evolve(
-        self,
-        prompt_1: str,
-        prompt_2: str,
-        **kwargs,
-    ):
-        # Following the evolutionary operators in GA, a new candidate prompt is generated through
-        # a two-step process based on the selected two parents:
-        # 1) The parent prompts undergo crossover, resulting in a new prompt that
-        # selectively combines components from both parents;
-        # 2) The newly generated prompt from the first step undergoes mutation,
-        # in which random alterations are made to some of its content.
-        # Based on this two-step process, we design instructions, guiding LLMs to
-        # generate a new prompt based on these steps to perform Evo(·) in Algorithm 1.
-
-        filled_prompt = GA_PROMPT.format(prompt1=prompt_1, prompt2=prompt_2)
-        evolved_prompt, messages, usage = self.evolution_model.create_completion(
-            system_message=SYSTEM_MESSAGE,
-            prompt=filled_prompt,
-        )
-
-        judgement = self.judge_and_correct_step(filled_prompt, evolved_prompt, messages)
-        evolved_prompt = judgement.corrected_response
-
-        if "<prompt>" in evolved_prompt:
-            evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
-
-        logger.info(
-            "GA-evolved prompts '%s' and '%s' into '%s'",
-            prompt_1,
-            prompt_2,
-            evolved_prompt,
-        )
-
-        return evolved_prompt, [judgement], usage
-
-    @log_calls("Performing update for GA")
-    def update(
-        self, prompts_current_evolution: list[Prompt], new_evolutions: list[Prompt]
-    ):
-        # EVOPROMPT iteratively generates new candidate prompts and assesses each prompt
-        # using a development set, denoted as D, to obtain a score that quantifies the
-        # quality of the prompt. We consider a straightforward selection strategy.
-        # Specifically, at each iteration, EVOPROMPT based on GA produces N new prompts,
-        # which are combined with the current population of N prompts.
-        # The updated population is then selected by retaining the N prompts with the highest scores.
-        retained_prompts: list[Prompt] = []
-        min_retained_score = 0
-        for prompt in prompts_current_evolution + new_evolutions:
-            if len(retained_prompts) < self.population_size:
-                retained_prompts.append(prompt)
-                min_retained_score = min(min_retained_score, prompt.score)
-            elif prompt.score > min_retained_score:
-                retained_prompts.sort(key=lambda p: p.score)
-                retained_prompts[0] = prompt
-
-        return retained_prompts
-
-
-class DifferentialEvolution(EvolutionAlgorithm):
-    """The differential algorithm (DE) implemented using LLMs."""
-
-    shorthand = "de"
-
-    @log_calls("Performing prompt evolution using DE")
-    def evolve(
-        self,
-        prompt_1: str,
-        prompt_2: str,
-        *,
-        prompts_current_evolution: list[Prompt],
-        current_iteration: int,
-    ):
-        # TODO add description from paper
-
-        # DE needs best prompt for evolution
-        best_prompt_current_evolution = max(
-            prompts_current_evolution, key=lambda prompt: prompt.score
-        )
-
-        filled_prompt = get_de_prompt_template(
-            self.use_evolution_demo, self.task
-        ).format(
-            prompt1=prompt_1,
-            prompt2=prompt_2,
-            prompt3=best_prompt_current_evolution,
-            basic_prompt=prompts_current_evolution[current_iteration],
-        )
-        evolved_prompt, messages, usage = self.evolution_model.create_completion(
-            system_message=SYSTEM_MESSAGE,
-            prompt=filled_prompt,
-        )
-
-        judgement = self.judge_and_correct_step(filled_prompt, evolved_prompt, messages)
-        evolved_prompt = judgement.corrected_response
-
-        matches = re.findall(
-            # regex that matches any characters between last pair of <prompt></prompt>, also if </prompt> is missing
-            r"<prompt>(?!.*<prompt>)(?:(.*)</prompt>|(.*))",
-            evolved_prompt,
-            flags=(re.IGNORECASE | re.DOTALL),
-        )
-        if matches and any(matches[0]):
-            # there is always only a single match, and one group should be non-empty
-            if matches[0][0]:
-                evolved_prompt = matches[0][0]
-            else:
-                assert matches[0][1]
-                evolved_prompt = matches[0][1]
-        else:
-            # TODO what to do in this case? Discard generated prompt directly?
-            pass
-
-        logger.info(
-            "DE-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
-            prompt_1,
-            prompt_2,
-            best_prompt_current_evolution,
-            prompts_current_evolution[current_iteration],
-            evolved_prompt,
-        )
-
-        return evolved_prompt, [judgement], usage
-
-    @log_calls("Performing update for DE")
-    def update(
-        self, prompts_current_evolution: list[Prompt], new_evolutions: list[Prompt]
-    ):
-        # for DE we keep the evolved prompt if it is better than the basic prompt, and use the basic prompt otherwise
-        assert len(prompts_current_evolution) == len(new_evolutions)
-        population = [
-            (new_prompt if new_prompt.score > current_prompt.score else current_prompt)
-            for current_prompt, new_prompt in zip(
-                prompts_current_evolution, new_evolutions
-            )
-        ]
-        return population
-
-
-class DifferentialEvolutionWithCot(DifferentialEvolution):
-    """The differential algorithm using Chain-of-Thought (DE-CoT) implemented using LLMs."""
-
-    shorthand = "de-cot"
-
-    @log_calls("Performing prompt evolution using DE-CoT")
-    def evolve(
-        self,
-        prompt_1: str,
-        prompt_2: str,
-        *,
-        prompts_current_evolution: list[Prompt],
-        current_iteration: int,
-    ):
-        # TODO add description
-
-        # DE needs best prompt for evolution
-        best_prompt_current_evolution = max(
-            prompts_current_evolution, key=lambda prompt: prompt.score
-        )
-
-        messages = None
-        response: str = ""
-        judgements: list[Judgement] = []
-        usage: ModelUsage = ModelUsage()
-        for idx, prompt in enumerate(DE_COT_PROMPTS):
-            filled_prompt = prompt.format(
-                prompt1=prompt_1,
-                prompt2=prompt_2,
-                prompt3=best_prompt_current_evolution,
-                basic_prompt=prompts_current_evolution[current_iteration],
-            )
-            response, messages, usage = self.evolution_model.create_completion(
-                system_message=SYSTEM_MESSAGE,
-                prompt=filled_prompt,
-                history=messages,
-                stop="</prompt>" if idx == len(DE_COT_PROMPTS) - 1 else None,
-            )
-            logger.debug(
-                "Performed evolution (step %d) using DE-CoT:\n\tInputs: %s\n\tResponse: %s",
-                idx,
-                messages,
-                response,
-            )
-            judgement = self.judge_and_correct_step(
-                filled_prompt, response, history=messages
-            )
-            judgements.append(judgement)
-            # replace last message with corrected response
-            messages[-1]["content"] = judgement.corrected_response
-            response = judgement.corrected_response
-
-        # at this point we should get a new prompt
-        if "<prompt>" in response:
-            response = response.split("<prompt>")[1].split("</prompt>")[0]
-
-        logger.info(
-            "DE-CoT-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
-            prompt_1,
-            prompt_2,
-            best_prompt_current_evolution,
-            prompts_current_evolution[current_iteration],
-            response,
-        )
-
-        return response, judgements, usage
-
-
-def get_all_subclasses(cls):
-    return set(cls.__subclasses__()).union(
-        [s for c in cls.__subclasses__() for s in get_all_subclasses(c)]
-    )
-
-
-optimizers = {
-    algorithm.shorthand: algorithm
-    for algorithm in get_all_subclasses(EvolutionAlgorithm)
-}
-
-
-def get_optimizer_class(name: str) -> type[EvolutionAlgorithm]:
-    if name not in optimizers:
-        raise ValueError("Optimization Algorithm %s does not exist", name)
-    return optimizers[name]
-
-
-argument_parser.add_argument(
-    "--evolution-algorithm", "-a", type=str, choices=optimizers.keys(), default="ga"
-)
-argument_parser.add_argument(
-    "--use-evolution-demo",
-    action="store_true",
-    help="Whether to prepend a single demonstration example for evolution or not",
-)
diff --git a/evoprompt/evolution/evolution.py b/evoprompt/evolution/evolution.py
index eb969f8..60a9cbf 100644
--- a/evoprompt/evolution/evolution.py
+++ b/evoprompt/evolution/evolution.py
@@ -336,7 +336,9 @@ class DifferentialEvolution(EvolutionAlgorithm):
 
     def get_prompt_template(self):
         if self.use_evolution_demo:
-            if isinstance(self.task, (TextClassification, Summarization)):
+            if isinstance(
+                self.task, (TextClassification, Summarization, QuestionAnswering)
+            ):
                 return get_demonstration_prompt_template(
                     DE_PROMPT, DE_DEMONSTRATION_DATA_SIM
                 )
diff --git a/evoprompt/task/topic_classification.py b/evoprompt/task/topic_classification.py
index 6e9107b..c4048a7 100644
--- a/evoprompt/task/topic_classification.py
+++ b/evoprompt/task/topic_classification.py
@@ -16,7 +16,10 @@ class TopicClassification(TextClassification):
 
 class AGNews(BasePromptsFromJsonMixin, TopicClassification):
     shorthand = "agn"
-    base_prompts_file = "evoprompt/initial_prompts/agnews/prompts.json"
+    base_prompts_files = [
+        "evoprompt/initial_prompts/agnews/prompts.json",
+        "evoprompt/initial_prompts/agnews/prompts_auto.json",
+    ]
 
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(
@@ -45,7 +48,10 @@ class AGNews(BasePromptsFromJsonMixin, TopicClassification):
 
 class TREC(BasePromptsFromJsonMixin, TopicClassification):
     shorthand = "trec"
-    base_prompts_file = "evoprompt/initial_prompts/trec/prompts.json"
+    base_prompts_files = [
+        "evoprompt/initial_prompts/trec/prompts.json",
+        "evoprompt/initial_prompts/trec/prompts_auto.json",
+    ]
 
     def load_validation_set(
         self, validation_dataset: str | None, validation_split: str | None
-- 
GitLab