From 3d76757df47798afc705cdbe45b292091f7cd6dd Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Wed, 16 Oct 2024 19:57:20 +0200
Subject: [PATCH] Re-add generated base prompts

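Tasks' `base_prompts` now returns a tuple `(prompts, sources)` so that
every initial prompt carries its provenance. `PromptSource`,
`Judgement`, and `PromptMeta` move from `optimization.py` into
`opt_types.py`, and `PromptSource` gains the variants
`baseprompt_file` and `baseprompt_gen`. The mixins in
`base_prompts.py` (renamed from `base_prompts_mixin.py`) now compose
cooperatively: each implementation extends the result of
`super().base_prompts` with its own prompts and matching source
labels, and `Task.base_prompts` terminates the chain with empty
lists. `BasePromptsFromGenerationMixin` re-enables generating base
prompts from validation samples, following Zhou et al., 2023 (APE).

A minimal sketch of how the cooperative chain composes; the classes
below are illustrative stand-ins, not the ones touched by this patch:

    class Base:
        @property
        def base_prompts(self) -> tuple[list[str], list[str]]:
            return [], []  # terminates the MRO chain

    class FileMixin(Base):
        @property
        def base_prompts(self):
            prompts, sources = super().base_prompts
            loaded = ["prompt loaded from a JSON file"]
            return prompts + loaded, sources + ["baseprompt_file"] * len(loaded)

    class MyTask(FileMixin):
        @property
        def base_prompts(self):
            prompts, sources = super().base_prompts
            return prompts + ["hand-written prompt"], sources + ["baseprompt"]

    assert MyTask().base_prompts == (
        ["prompt loaded from a JSON file", "hand-written prompt"],
        ["baseprompt_file", "baseprompt"],
    )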
---
 evoprompt/models.py                           |  8 ++-
 evoprompt/opt_types.py                        | 24 +++++++++
 evoprompt/optimization.py                     | 38 +++++---------
 ...{base_prompts_mixin.py => base_prompts.py} | 52 ++++++++++++-------
 evoprompt/task/question_answering.py          | 12 ++---
 evoprompt/task/sentiment_analysis.py          |  2 +-
 evoprompt/task/simplification.py              |  2 +-
 evoprompt/task/subjectivity_classification.py |  2 +-
 evoprompt/task/summarization.py               |  2 +-
 evoprompt/task/task.py                        |  5 +-
 evoprompt/task/topic_classification.py        |  2 +-
 11 files changed, 92 insertions(+), 57 deletions(-)
 rename evoprompt/task/{base_prompts_mixin.py => base_prompts.py} (53%)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 9566f15..4b0b4ab 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -427,8 +427,12 @@ class ChatModel:
             use_randomness=use_randomness,
         )
 
-        messages.append(self._get_assistant_message(reponse))
-        return reponse, history, messages, usage
+        return (
+            reponse,
+            history,
+            messages + [self._get_assistant_message(reponse)],
+            usage,
+        )
 
 
 class LlamaChat(ChatModel, Llama):
diff --git a/evoprompt/opt_types.py b/evoprompt/opt_types.py
index a80c850..b700da4 100644
--- a/evoprompt/opt_types.py
+++ b/evoprompt/opt_types.py
@@ -1,5 +1,6 @@
 import json
 from dataclasses import dataclass, field, is_dataclass
+from typing import Literal, NamedTuple, TypedDict
 from uuid import uuid4
 
 
@@ -25,6 +26,29 @@ class ModelUsage:
         )
 
 
+PromptSource = Literal[
+    "baseprompt",
+    "baseprompt_file",
+    "baseprompt_gen",
+    "paraphrase",
+    "evolution",
+    "corrected",
+]
+
+
+class Judgement(NamedTuple):
+    original_response: str
+    corrected_response: str
+    happy: bool | None
+    skip: bool
+
+
+class PromptMeta(TypedDict):
+    gen: int
+    source: PromptSource
+    judgements: list[Judgement]
+
+
 @dataclass(frozen=True)
 class Prompt:
     content: str
diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index de7ffb5..db40b63 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -4,7 +4,7 @@ import re
 from collections import OrderedDict
 from difflib import Differ
 from pathlib import Path
-from typing import Any, Literal, NamedTuple, Optional, TypedDict
+from typing import Any, Optional
 
 import wandb
 from rich.panel import Panel
@@ -18,7 +18,13 @@ from tqdm import tqdm, trange
 
 from evoprompt.cli import argument_parser
 from evoprompt.models import ChatMessages, LLMModel
-from evoprompt.opt_types import ModelUsage, OptTypeEncoder, Prompt
+from evoprompt.opt_types import (
+    Judgement,
+    ModelUsage,
+    OptTypeEncoder,
+    Prompt,
+    PromptMeta,
+)
 from evoprompt.task import Task
 from evoprompt.utils import initialize_run_directory, log_calls
 
@@ -27,21 +33,6 @@ logger = logging.getLogger(__name__)
 
 PARAPHRASE_PROMPT = """You are given an instruction that describes a task. Write a response that paraphrases the instruction. Only output the paraphrased instruction bracketed in <prompt> and </prompt>."""
 
-PromptSource = Literal["baseprompt", "paraphrase", "evolution", "corrected"]
-
-
-class Judgement(NamedTuple):
-    original_response: str
-    corrected_response: str
-    happy: bool | None
-    skip: bool
-
-
-class PromptMeta(TypedDict):
-    gen: int
-    source: PromptSource
-    judgements: list[Judgement]
-
 
 class ResponseEditor(App):
     BINDINGS = [
@@ -174,7 +165,7 @@ class PromptOptimization:
 
     def get_initial_prompts(self, num_initial_prompts: int, debug: bool = False):
         # this implements the para_topk algorithm from https://github.com/beeevita/EvoPrompt
-        base_prompts = self.task.base_prompts
+        base_prompts, base_prompts_sources = self.task.base_prompts
         if debug:
             base_prompts = base_prompts[:2]
 
@@ -187,15 +178,14 @@ class PromptOptimization:
 
         # take at most half of the best prompts
         sorted_results = sorted(
-            zip(evaluation_results, base_prompts),
+            zip(evaluation_results, base_prompts, base_prompts_sources),
             key=lambda x: x[0][0],  # sort by score
             reverse=True,  # best first
         )
-        top_prompts = [
-            prompt for _, prompt in sorted_results[: num_initial_prompts // 2]
-        ]
-        initial_population = top_prompts.copy()
-        prompt_sources = ["baseprompt" for _ in initial_population]
+        sorted_results = sorted_results[: num_initial_prompts // 2]
+        _, top_prompts, prompt_sources = zip(*sorted_results)
+        initial_population = list(top_prompts)
+        prompt_sources = list(prompt_sources)
 
         # fill up the rest with paraphrases of the top prompts
         promptindex_to_paraphrase = 0
diff --git a/evoprompt/task/base_prompts_mixin.py b/evoprompt/task/base_prompts.py
similarity index 53%
rename from evoprompt/task/base_prompts_mixin.py
rename to evoprompt/task/base_prompts.py
index c4d00cd..2eb9ad7 100644
--- a/evoprompt/task/base_prompts_mixin.py
+++ b/evoprompt/task/base_prompts.py
@@ -10,7 +10,7 @@ from evoprompt.utils import get_rng
 
 class BasePromptsFromJsonMixin:
     @staticmethod
-    def _load_json_file(path: str):
+    def _load_json_file(path: str) -> list[str]:
         with Path(path).open() as json_file:
             return json.load(json_file)
 
@@ -20,41 +20,48 @@ class BasePromptsFromJsonMixin:
             raise Exception(
                 f"Class {self.__class__} does not exhibit attribute `base_prompts_files` which is needed for `BasePromptsFromJsonMixin`."
             )
-        base_prompts = []
+        prompts, sources = super().base_prompts
+        prompts_from_files = []
         for prompt_file in self.base_prompts_files:
-            base_prompts += self._load_json_file(prompt_file)
-        return base_prompts
+            prompts_from_files += self._load_json_file(prompt_file)
+        prompts += prompts_from_files
+        sources += ["baseprompt_file"] * len(prompts_from_files)
+        return prompts, sources
 
 
-class BasePromptsFromGeneration:
+class BasePromptsFromGenerationMixin:
     def __init__(self, *args, **kwargs) -> None:
         self.evolution_model: LLMModel = kwargs.get("evolution_model")
         super().__init__(*args, **kwargs)
 
     # this implements the initial population generation from Zhou et al., 2023: Large Language Models are Human-Level Prompt Engineers
+    # patience allows stopping generation early if no new prompts can be extracted from the model's responses;
+    # it can be set to -1 to generate as many prompts as needed (but this may run forever)
     def generate_prompt(
         self, num_prompts: int, patience: int = 10, allow_duplicates: bool = False
-    ) -> str:
+    ) -> list[str]:
         self.validation_dataset: Dataset
         samples = self.validation_dataset.shuffle(42).select(
             get_rng().choice(len(self.validation_dataset), 5, replace=False)
         )
-        prompt = "I gave a friend an instruction and five inputs. The friend read the instruction and wrote an output for every one of the inputs. Here are the input-output pairs:\n"
-        raise NotImplementedError(
-            "The prompt needs to be adapted for the model taking into account the correct format."
+        prompt = "I gave a friend a single instruction and five inputs. The friend read the instruction and wrote an output for every one of the inputs. Here are the input-output pairs:\n\n"
+        prompt += "\n".join(
+            f"Input:\n{self._get_prompt_text_for_datum(sample)}\nOutput:\n{self._get_gold_label_generation_for_datum(sample)}\n"
+            for sample in samples
         )
-        prompt = self.build_demonstration_prompt(samples, prompt=prompt)
         prompt += "\nThe instruction was "
         system_message = "You are a helpful assistant. Please provide the instruction wrapped within tags <instruction> and </instruction> that belongs to the given input-output pairs."
-        input(prompt)
+        messages = [
+            self.evolution_model._get_user_message(prompt)
+        ]
 
         generated_prompts = []
         while len(generated_prompts) < num_prompts:
             response, _, _, _ = self.evolution_model.create_completion(
                 system_message=system_message,
-                prompt=prompt,
+                messages=messages,
+                use_randomness=True,
             )
-            input(response)
             matches = re.findall(
                 # regex that extracts anything within tags <instruction> and optional </instruction>
                 rf"<instruction>(.+?)(?:(?=</instruction>)|$)",
@@ -62,9 +69,9 @@ class BasePromptsFromGeneration:
                 flags=re.IGNORECASE,
             )
             if matches:
-                prompt = matches[-1].strip()
-                if allow_duplicates or prompt not in generated_prompts:
-                    generated_prompts.append(matches[-1].strip())
+                generated_prompt = matches[-1].strip()
+                if allow_duplicates or generated_prompt not in generated_prompts:
+                    generated_prompts.append(generated_prompt)
             else:
                 if patience == 0:
                     break
@@ -74,6 +81,15 @@ class BasePromptsFromGeneration:
 
     @property
     def base_prompts(self):
-        num_prompts = getattr(self, "num_generated_base_prompts", 0)
+        if not hasattr(self, "num_generated_base_prompts"):
+            raise AttributeError(
+                f"{self.__class__} must expose attribute `num_generated_base_prompts`"
+            )
+        prompts, sources = super().base_prompts
+
+        num_prompts = self.num_generated_base_prompts
+        generated_prompts = self.generate_prompt(num_prompts)
+        prompts += generated_prompts
+        sources += ["baseprompt_gen"] * len(generated_prompts)
 
-        return self.generate_prompt(num_prompts)
+        return prompts, sources
diff --git a/evoprompt/task/question_answering.py b/evoprompt/task/question_answering.py
index 6266c86..fddf81a 100644
--- a/evoprompt/task/question_answering.py
+++ b/evoprompt/task/question_answering.py
@@ -8,7 +8,7 @@ from datasets import Dataset
 from evaluate import load as load_metric
 from llama_cpp import LlamaGrammar
 
-from evoprompt.task.base_prompts_mixin import BasePromptsFromGeneration
+from evoprompt.task.base_prompts import BasePromptsFromGenerationMixin
 from evoprompt.opt_types import ModelUsage
 from evoprompt.task.task import DatasetDatum, Task
 from evoprompt.utils import get_rng
@@ -172,7 +172,7 @@ class QuestionAnswering(Task):
         return "f1"
 
 
-class SQuAD(BasePromptsFromGeneration, QuestionAnswering):
+class SQuAD(BasePromptsFromGenerationMixin, QuestionAnswering):
     shorthand = "squad"
     num_generated_base_prompts = 10
 
@@ -202,7 +202,7 @@ class SQuAD(BasePromptsFromGeneration, QuestionAnswering):
 
     @property
     def base_prompts(self):
-        generated_base_prompts = super().base_prompts
-        return [
-            "In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."
-        ] + generated_base_prompts
+        prompts, sources = super().base_prompts
+        prompts.append("In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context.")
+        sources.append("baseprompt")
+        return prompts, sources
diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py
index bf72c17..89b092a 100644
--- a/evoprompt/task/sentiment_analysis.py
+++ b/evoprompt/task/sentiment_analysis.py
@@ -4,7 +4,7 @@ from typing import Mapping
 
 from datasets import load_dataset
 
-from evoprompt.task.base_prompts_mixin import BasePromptsFromJsonMixin
+from evoprompt.task.base_prompts import BasePromptsFromJsonMixin
 from evoprompt.task import TextClassification
 from evoprompt.task.task import DatasetDatum
 
diff --git a/evoprompt/task/simplification.py b/evoprompt/task/simplification.py
index 37e0ec5..3a2e6aa 100644
--- a/evoprompt/task/simplification.py
+++ b/evoprompt/task/simplification.py
@@ -2,7 +2,7 @@ import logging
 
 from evaluate import load as load_metric
 
-from evoprompt.task.base_prompts_mixin import BasePromptsFromJsonMixin
+from evoprompt.task.base_prompts import BasePromptsFromJsonMixin
 from evoprompt.task import TextGeneration
 from evoprompt.task.task import DatasetDatum
 
diff --git a/evoprompt/task/subjectivity_classification.py b/evoprompt/task/subjectivity_classification.py
index 7c3882e..10fea8b 100644
--- a/evoprompt/task/subjectivity_classification.py
+++ b/evoprompt/task/subjectivity_classification.py
@@ -3,7 +3,7 @@ from typing import Mapping
 
 from datasets import load_dataset
 
-from evoprompt.task.base_prompts_mixin import BasePromptsFromJsonMixin
+from evoprompt.task.base_prompts import BasePromptsFromJsonMixin
 from evoprompt.task import TextClassification
 from evoprompt.task.task import DatasetDatum
 
diff --git a/evoprompt/task/summarization.py b/evoprompt/task/summarization.py
index fed21c1..4084d27 100644
--- a/evoprompt/task/summarization.py
+++ b/evoprompt/task/summarization.py
@@ -2,7 +2,7 @@ import logging
 
 from evaluate import load as load_metric
 
-from evoprompt.task.base_prompts_mixin import BasePromptsFromJsonMixin
+from evoprompt.task.base_prompts import BasePromptsFromJsonMixin
 from evoprompt.task import TextGeneration
 from evoprompt.task.task import DatasetDatum
 
diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index 3b59db3..90f128b 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -11,7 +11,7 @@ from llama_cpp import LlamaGrammar
 from tqdm import tqdm
 
 from evoprompt.models import ChatMessage, ChatMessages, LLMModel
-from evoprompt.opt_types import ModelUsage
+from evoprompt.opt_types import ModelUsage, PromptSource
 from evoprompt.utils import log_calls
 
 logger = logging.getLogger(__name__)
@@ -492,4 +492,5 @@ class Task(metaclass=ABCMeta):
 
     @property
     @abstractmethod
-    def base_prompts(self) -> list[str]: ...
+    def base_prompts(self) -> tuple[list[str], list[PromptSource]]:
+        return [], []
diff --git a/evoprompt/task/topic_classification.py b/evoprompt/task/topic_classification.py
index dd1905f..6230f85 100644
--- a/evoprompt/task/topic_classification.py
+++ b/evoprompt/task/topic_classification.py
@@ -3,7 +3,7 @@ from typing import Mapping
 
 from datasets import load_dataset
 
-from evoprompt.task.base_prompts_mixin import BasePromptsFromJsonMixin
+from evoprompt.task.base_prompts import BasePromptsFromJsonMixin
 from evoprompt.task import TextClassification
 from evoprompt.task.task import DatasetDatum
 
-- 
GitLab