From 31c10c46c9bc286abf999ee092a2d5763c18945c Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Tue, 3 Sep 2024 17:11:41 +0200
Subject: [PATCH] Fix topic classification and question answering tasks

---
 evoprompt/evolution.py                 | 412 -------------------------
 evoprompt/evolution/evolution.py       |   4 +-
 evoprompt/task/topic_classification.py |  10 +-
 3 files changed, 11 insertions(+), 415 deletions(-)
 delete mode 100644 evoprompt/evolution.py

diff --git a/evoprompt/evolution.py b/evoprompt/evolution.py
deleted file mode 100644
index 0bf36db..0000000
--- a/evoprompt/evolution.py
+++ /dev/null
@@ -1,412 +0,0 @@
-import logging
-import re
-from abc import ABCMeta, abstractmethod
-from typing import Any
-
-from tqdm import trange
-
-from evoprompt.cli import argument_parser
-from evoprompt.evolution.template_de import get_de_prompt_template
-from evoprompt.models import LLMModel
-from evoprompt.opt_types import ModelUsage, Prompt
-from evoprompt.optimization import Judgement, PromptOptimization
-from evoprompt.task import Task
-from evoprompt.utils import get_all_subclasses, get_rng, log_calls
-
-logger = logging.getLogger(__name__)
-
-
-SYSTEM_MESSAGE = "Please carefully follow the instruction step-by-step."
-
-GA_PROMPT = """
-1. Cross over the following prompts and generate a new prompt:
-Prompt 1: {prompt1}
-Prompt 2: {prompt2}
-2. Mutate the prompt generated in Step 1 and generate a final prompt bracketed with <prompt> and </prompt>.
-"""
-
-
-DE_COT_PROMPTS = [
-    "Step 1: Identify the main different parts between the Prompt 1 and Prompt 2:\nPrompt 1: {prompt1}\nPrompt 2: {prompt2}",
-    "Step 2: Randomly mutate the different parts",
-    "Step 3: Combine the different parts with Prompt 3, selectively replace it with the different parts in Step 2 and generate a new prompt.\nPrompt 3: {prompt3}",
-    "Step 4: Cross over the prompt in the Step 3 with the following basic prompt and generate a final prompt bracketed with <prompt> and </prompt>:\nBasic Prompt: {basic_prompt}",
-]
-
-
-class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
-    shorthand: str
-
-    """The super class for all evolution algorithms containing shared parameters."""
-
-    def __init__(
-        self,
-        population_size: int,
-        *,
-        task: Task,
-        evolution_model: LLMModel,
-        evaluation_model: LLMModel,
-        judge_model: LLMModel | None,
-        run_options: dict[str, Any] = {},
-    ) -> None:
-        super().__init__(
-            task=task,
-            evolution_model=evolution_model,
-            evaluation_model=evaluation_model,
-            judge_model=judge_model,
-            run_options=run_options,
-        )
-        self.use_evolution_demo = run_options.get("use_evolution_demo", False)
-
-        self.population_size = population_size
-
-    @log_calls("Performing selection")
-    def select(self, prompts: list[Prompt]):
-        # In GA, two parent solutions are normally selected based on the roulette wheel
-        # selection method according to the fitness value (Lipowski & Lipowska, 2012).
-        # Similar to this, we utilize the roulette wheel selection method to select
-        # two parent prompts in the current population according to the scores evaluated
-        # on development sets. Specifically, let si denote the performance score on the
-        # development set of the i-th prompt in the population, which contains a total
-        # of N prompts. The probability of selecting the i-th prompt as a parent can be expressed as
-        # pi = si / Σj=1->N sj.
-        # add small value to avoid zero chance of selection for some prompts
-        scores = [prompt.score + 1e-6 for prompt in prompts]
-        selection_probabilities = [score / sum(scores) for score in scores]
-        return get_rng().choice(
-            prompts, size=2, replace=False, p=selection_probabilities
-        )
-
-    @abstractmethod
-    def evolve(
-        self,
-        prompt_1: str,
-        prompt_2: str,
-        *,
-        prompts_current_evolution: list[Prompt],
-        current_iteration: int,
-    ) -> tuple[str, list[Judgement], ModelUsage]:
-        pass
-
-    @abstractmethod
-    def update(self, *args, **kwargs):
-        pass
-
-    def run(self, num_iterations: int, debug: bool = False) -> None:
-        # debug mode for quick run
-        if debug:
-            self.population_size = 3
-            num_iterations = 2
-
-        self.init_run(self.population_size, num_iterations, debug=debug)
-
-        # Algorithm 1 Discrete prompt optimization: EVOPROMPT
-
-        # Line 2:
-        for t in self.iterations_pbar:
-            # Line 3: Selection: select a certain number of prompts from current population as parent prompts
-            # pr1,...,prk ∼ Pt−1
-            prompts_current_evolution = self.P[t - 1]
-
-            new_evolutions = []
-
-            for i in trange(self.population_size, desc="updates", leave=False):
-                # for both GA and DE we start with two parent prompts
-                pr1, pr2 = self.select(self.P[t - 1])
-
-                # Line 4: Evolution: generate a new prompt based on the selected parent prompts by leveraging LLM to perform evolutionary operators
-                # p′i ←Evo(pr1,...,prk)
-                (
-                    p_i,
-                    judgements,
-                    evolution_usage,
-                ) = self.evolve(
-                    pr1,
-                    pr2,
-                    prompts_current_evolution=prompts_current_evolution,
-                    current_iteration=i,
-                )
-                self.total_evolution_usage += evolution_usage
-
-                prompt_source = (
-                    "corrected" if not all(j.happy for j in judgements) else "generated"
-                )
-                evolved_prompt = self.add_prompt(
-                    p_i,
-                    parents=(pr1, pr2),
-                    meta={"gen": t, "source": prompt_source, "judgements": judgements},
-                )
-                self.total_evaluation_usage += evolved_prompt.usage
-
-                new_evolutions.append(evolved_prompt)
-                self.save_snapshot()
-            # Line 6: Update based on the evaluation scores
-            # Pt ← {Pt−1, p′i} and St ← {St−1, s′i}
-            new_population = self.update(new_evolutions, prompts_current_evolution)
-
-            # store new generation
-            self.P.append(new_population)
-            self.save_snapshot()
-
-        self.save_snapshot()
-        # Line 8: Return the best prompt, p∗, among the final population PT :
-        # p∗ ← argmaxp∈PT f(p, D)
-        p = max(self.P[-1], key=lambda prompt: self.all_prompts[prompt.id].score)
-        logger.info("Best prompt with score %.2f: %s", p.score, p)
-
-        # We pick the prompt with the highest score on the development set and report its score on the testset.
-        test_performance, _, _ = self.task.evaluate_test(p.content)
-        logger.info(
-            "Best prompt on test set: %s %s", test_performance, self.task.metric_name
-        )
-        logger.info(
-            "Usage (evolution model / evaluation model / total): %s / %s / %s",
-            self.total_evolution_usage,
-            self.total_evaluation_usage,
-            self.total_evolution_usage + self.total_evaluation_usage,
-        )
-
-        return self.total_evolution_usage, self.total_evaluation_usage
-
-
-class GeneticAlgorithm(EvolutionAlgorithm):
-    """The genetic algorithm (GA) implemented using LLMs."""
-
-    shorthand = "ga"
-
-    # kwargs is just there for convenience, as evolve function of other optimizers might have different inputs
-    # @register_action(ignore_args=["kwargs"])
-    @log_calls("Performing prompt evolution using GA")
-    def evolve(
-        self,
-        prompt_1: str,
-        prompt_2: str,
-        **kwargs,
-    ):
-        # Following the evolutionary operators in GA, a new candidate prompt is generated through
-        # a two-step process based on the selected two parents:
-        # 1) The parent prompts undergo crossover, resulting in a new prompt that
-        #   selectively combines components from both parents;
-        # 2) The newly generated prompt from the first step undergoes mutation,
-        #   in which random alterations are made to some of its content.
-        # Based on this two-step process, we design instructions, guiding LLMs to
-        # generate a new prompt based on these steps to perform Evo(·) in Algorithm 1.
-
-        filled_prompt = GA_PROMPT.format(prompt1=prompt_1, prompt2=prompt_2)
-        evolved_prompt, messages, usage = self.evolution_model.create_completion(
-            system_message=SYSTEM_MESSAGE,
-            prompt=filled_prompt,
-        )
-
-        judgement = self.judge_and_correct_step(filled_prompt, evolved_prompt, messages)
-        evolved_prompt = judgement.corrected_response
-
-        if "<prompt>" in evolved_prompt:
-            evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
-
-        logger.info(
-            "GA-evolved prompts '%s' and '%s' into '%s'",
-            prompt_1,
-            prompt_2,
-            evolved_prompt,
-        )
-
-        return evolved_prompt, [judgement], usage
-
-    @log_calls("Performing update for GA")
-    def update(
-        self, prompts_current_evolution: list[Prompt], new_evolutions: list[Prompt]
-    ):
-        # EVOPROMPT iteratively generates new candidate prompts and assesses each prompt
-        # using a development set, denoted as D, to obtain a score that quantifies the
-        # quality of the prompt. We consider a straightforward selection strategy.
-        # Specifically, at each iteration, EVOPROMPT based on GA produces N new prompts,
-        # which are combined with the current population of N prompts.
-        # The updated population is then selected by retaining the N prompts with the highest scores.
-        retained_prompts: list[Prompt] = []
-        min_retained_score = 0
-        for prompt in prompts_current_evolution + new_evolutions:
-            if len(retained_prompts) < self.population_size:
-                retained_prompts.append(prompt)
-                min_retained_score = min(min_retained_score, prompt.score)
-            elif prompt.score > min_retained_score:
-                retained_prompts.sort(key=lambda p: p.score)
-                retained_prompts[0] = prompt
-
-        return retained_prompts
-
-
-class DifferentialEvolution(EvolutionAlgorithm):
-    """The differential algorithm (DE) implemented using LLMs."""
-
-    shorthand = "de"
-
-    @log_calls("Performing prompt evolution using DE")
-    def evolve(
-        self,
-        prompt_1: str,
-        prompt_2: str,
-        *,
-        prompts_current_evolution: list[Prompt],
-        current_iteration: int,
-    ):
-        # TODO add description from paper
-
-        # DE needs best prompt for evolution
-        best_prompt_current_evolution = max(
-            prompts_current_evolution, key=lambda prompt: prompt.score
-        )
-
-        filled_prompt = get_de_prompt_template(
-            self.use_evolution_demo, self.task
-        ).format(
-            prompt1=prompt_1,
-            prompt2=prompt_2,
-            prompt3=best_prompt_current_evolution,
-            basic_prompt=prompts_current_evolution[current_iteration],
-        )
-        evolved_prompt, messages, usage = self.evolution_model.create_completion(
-            system_message=SYSTEM_MESSAGE,
-            prompt=filled_prompt,
-        )
-
-        judgement = self.judge_and_correct_step(filled_prompt, evolved_prompt, messages)
-        evolved_prompt = judgement.corrected_response
-
-        matches = re.findall(
-            # regex that matches any characters between last pair of <prompt></prompt>, also if </prompt> is missing
-            r"<prompt>(?!.*<prompt>)(?:(.*)</prompt>|(.*))",
-            evolved_prompt,
-            flags=(re.IGNORECASE | re.DOTALL),
-        )
-        if matches and any(matches[0]):
-            # there is always only a single match, and one group should be non-empty
-            if matches[0][0]:
-                evolved_prompt = matches[0][0]
-            else:
-                assert matches[0][1]
-                evolved_prompt = matches[0][1]
-        else:
-            # TODO what to do in this case? Discard generated prompt directly?
-            pass
-
-        logger.info(
-            "DE-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
-            prompt_1,
-            prompt_2,
-            best_prompt_current_evolution,
-            prompts_current_evolution[current_iteration],
-            evolved_prompt,
-        )
-
-        return evolved_prompt, [judgement], usage
-
-    @log_calls("Performing update for DE")
-    def update(
-        self, prompts_current_evolution: list[Prompt], new_evolutions: list[Prompt]
-    ):
-        # for DE we keep the evolved prompt if it is better than the basic prompt, and use the basic prompt otherwise
-        assert len(prompts_current_evolution) == len(new_evolutions)
-        population = [
-            (new_prompt if new_prompt.score > current_prompt.score else current_prompt)
-            for current_prompt, new_prompt in zip(
-                prompts_current_evolution, new_evolutions
-            )
-        ]
-        return population
-
-
-class DifferentialEvolutionWithCot(DifferentialEvolution):
-    """The differential algorithm using Chain-of-Thought (DE-CoT) implemented using LLMs."""
-
-    shorthand = "de-cot"
-
-    @log_calls("Performing prompt evolution using DE-CoT")
-    def evolve(
-        self,
-        prompt_1: str,
-        prompt_2: str,
-        *,
-        prompts_current_evolution: list[Prompt],
-        current_iteration: int,
-    ):
-        # TODO add description
-
-        # DE needs best prompt for evolution
-        best_prompt_current_evolution = max(
-            prompts_current_evolution, key=lambda prompt: prompt.score
-        )
-
-        messages = None
-        response: str = ""
-        judgements: list[Judgement] = []
-        usage: ModelUsage = ModelUsage()
-        for idx, prompt in enumerate(DE_COT_PROMPTS):
-            filled_prompt = prompt.format(
-                prompt1=prompt_1,
-                prompt2=prompt_2,
-                prompt3=best_prompt_current_evolution,
-                basic_prompt=prompts_current_evolution[current_iteration],
-            )
-            response, messages, usage = self.evolution_model.create_completion(
-                system_message=SYSTEM_MESSAGE,
-                prompt=filled_prompt,
-                history=messages,
-                stop="</prompt>" if idx == len(DE_COT_PROMPTS) - 1 else None,
-            )
-            logger.debug(
-                "Performed evolution (step %d) using DE-CoT:\n\tInputs: %s\n\tResponse: %s",
-                idx,
-                messages,
-                response,
-            )
-            judgement = self.judge_and_correct_step(
-                filled_prompt, response, history=messages
-            )
-            judgements.append(judgement)
-            # replace last message with corrected response
-            messages[-1]["content"] = judgement.corrected_response
-            response = judgement.corrected_response
-
-        # at this point we should get a new prompt
-        if "<prompt>" in response:
-            response = response.split("<prompt>")[1].split("</prompt>")[0]
-
-        logger.info(
-            "DE-CoT-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
-            prompt_1,
-            prompt_2,
-            best_prompt_current_evolution,
-            prompts_current_evolution[current_iteration],
-            response,
-        )
-
-        return response, judgements, usage
-
-
-def get_all_subclasses(cls):
-    return set(cls.__subclasses__()).union(
-        [s for c in cls.__subclasses__() for s in get_all_subclasses(c)]
-    )
-
-
-optimizers = {
-    algorithm.shorthand: algorithm
-    for algorithm in get_all_subclasses(EvolutionAlgorithm)
-}
-
-
-def get_optimizer_class(name: str) -> type[EvolutionAlgorithm]:
-    if name not in optimizers:
-        raise ValueError("Optimization Algorithm %s does not exist", name)
-    return optimizers[name]
-
-
-argument_parser.add_argument(
-    "--evolution-algorithm", "-a", type=str, choices=optimizers.keys(), default="ga"
-)
-argument_parser.add_argument(
-    "--use-evolution-demo",
-    action="store_true",
-    help="Whether to prepend a single demonstration example for evolution or not",
-)
diff --git a/evoprompt/evolution/evolution.py b/evoprompt/evolution/evolution.py
index eb969f8..60a9cbf 100644
--- a/evoprompt/evolution/evolution.py
+++ b/evoprompt/evolution/evolution.py
@@ -336,7 +336,9 @@ class DifferentialEvolution(EvolutionAlgorithm):
 
     def get_prompt_template(self):
         if self.use_evolution_demo:
-            if isinstance(self.task, (TextClassification, Summarization)):
+            if isinstance(
+                self.task, (TextClassification, Summarization, QuestionAnswering)
+            ):
                 return get_demonstration_prompt_template(
                     DE_PROMPT, DE_DEMONSTRATION_DATA_SIM
                 )
diff --git a/evoprompt/task/topic_classification.py b/evoprompt/task/topic_classification.py
index 6e9107b..c4048a7 100644
--- a/evoprompt/task/topic_classification.py
+++ b/evoprompt/task/topic_classification.py
@@ -16,7 +16,10 @@ class TopicClassification(TextClassification):
 
 class AGNews(BasePromptsFromJsonMixin, TopicClassification):
     shorthand = "agn"
-    base_prompts_file = "evoprompt/initial_prompts/agnews/prompts.json"
+    base_prompts_files = [
+        "evoprompt/initial_prompts/agnews/prompts.json",
+        "evoprompt/initial_prompts/agnews/prompts_auto.json",
+    ]
 
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(
@@ -45,7 +48,10 @@ class AGNews(BasePromptsFromJsonMixin, TopicClassification):
 
 class TREC(BasePromptsFromJsonMixin, TopicClassification):
     shorthand = "trec"
-    base_prompts_file = "evoprompt/initial_prompts/trec/prompts.json"
+    base_prompts_files = [
+        "evoprompt/initial_prompts/trec/prompts.json",
+        "evoprompt/initial_prompts/trec/prompts_auto.json",
+    ]
 
     def load_validation_set(
         self, validation_dataset: str | None, validation_split: str | None
-- 
GitLab