From 3d76757df47798afc705cdbe45b292091f7cd6dd Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Wed, 16 Oct 2024 19:57:20 +0200
Subject: [PATCH] Re-add generated base prompts

---
 evoprompt/models.py                           |  8 ++-
 evoprompt/opt_types.py                        | 24 +++++++++
 evoprompt/optimization.py                     | 38 +++++---------
 ...{base_prompts_mixin.py => base_prompts.py} | 52 ++++++++++++-------
 evoprompt/task/question_answering.py          | 12 ++---
 evoprompt/task/sentiment_analysis.py          |  2 +-
 evoprompt/task/simplification.py              |  2 +-
 evoprompt/task/subjectivity_classification.py |  2 +-
 evoprompt/task/summarization.py               |  2 +-
 evoprompt/task/task.py                        |  5 +-
 evoprompt/task/topic_classification.py        |  2 +-
 11 files changed, 92 insertions(+), 57 deletions(-)
 rename evoprompt/task/{base_prompts_mixin.py => base_prompts.py} (53%)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 9566f15..4b0b4ab 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -427,8 +427,12 @@ class ChatModel:
             use_randomness=use_randomness,
         )
 
-        messages.append(self._get_assistant_message(reponse))
-        return reponse, history, messages, usage
+        return (
+            reponse,
+            history,
+            messages + [self._get_assistant_message(reponse)],
+            usage,
+        )
 
 
 class LlamaChat(ChatModel, Llama):
diff --git a/evoprompt/opt_types.py b/evoprompt/opt_types.py
index a80c850..b700da4 100644
--- a/evoprompt/opt_types.py
+++ b/evoprompt/opt_types.py
@@ -1,5 +1,6 @@
 import json
 from dataclasses import dataclass, field, is_dataclass
+from typing import Literal, NamedTuple, TypedDict
 from uuid import uuid4
 
 
@@ -25,6 +26,29 @@ class ModelUsage:
         )
 
 
+PromptSource = Literal[
+    "baseprompt",
+    "baseprompt_file",
+    "baseprompt_gen",
+    "paraphrase",
+    "evolution",
+    "corrected",
+]
+
+
+class Judgement(NamedTuple):
+    original_response: str
+    corrected_response: str
+    happy: bool | None
+    skip: bool
+
+
+class PromptMeta(TypedDict):
+    gen: int
+    source: PromptSource
+    judgements: list[Judgement]
+
+
 @dataclass(frozen=True)
 class Prompt:
     content: str
diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index de7ffb5..db40b63 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -4,7 +4,7 @@ import re
 from collections import OrderedDict
 from difflib import Differ
 from pathlib import Path
-from typing import Any, Literal, NamedTuple, Optional, TypedDict
+from typing import Any, Optional
 
 import wandb
 from rich.panel import Panel
@@ -18,7 +18,13 @@ from tqdm import tqdm, trange
 
 from evoprompt.cli import argument_parser
 from evoprompt.models import ChatMessages, LLMModel
-from evoprompt.opt_types import ModelUsage, OptTypeEncoder, Prompt
+from evoprompt.opt_types import (
+    Judgement,
+    ModelUsage,
+    OptTypeEncoder,
+    Prompt,
+    PromptMeta,
+)
 from evoprompt.task import Task
 from evoprompt.utils import initialize_run_directory, log_calls
 
@@ -27,21 +33,6 @@ logger = logging.getLogger(__name__)
 PARAPHRASE_PROMPT = """You are given an instruction that describes a task. Write a response that paraphrases the instruction. 
 Only output the paraphrased instruction bracketed in <prompt> and </prompt>."""
 
-PromptSource = Literal["baseprompt", "paraphrase", "evolution", "corrected"]
-
-
-class Judgement(NamedTuple):
-    original_response: str
-    corrected_response: str
-    happy: bool | None
-    skip: bool
-
-
-class PromptMeta(TypedDict):
-    gen: int
-    source: PromptSource
-    judgements: list[Judgement]
-
 
 class ResponseEditor(App):
     BINDINGS = [
@@ -174,7 +165,7 @@ class PromptOptimization:
 
     def get_initial_prompts(self, num_initial_prompts: int, debug: bool = False):
         # this implements the para_topk algorithm from https://github.com/beeevita/EvoPrompt
-        base_prompts = self.task.base_prompts
+        base_prompts, base_prompts_sources = self.task.base_prompts
 
         if debug:
             base_prompts = base_prompts[:2]
@@ -187,15 +178,14 @@
 
         # take at most half of the best prompts
         sorted_results = sorted(
-            zip(evaluation_results, base_prompts),
+            zip(evaluation_results, base_prompts, base_prompts_sources),
             key=lambda x: x[0][0],  # sort by score
             reverse=True,  # best first
         )
-        top_prompts = [
-            prompt for _, prompt in sorted_results[: num_initial_prompts // 2]
-        ]
-        initial_population = top_prompts.copy()
-        prompt_sources = ["baseprompt" for _ in initial_population]
+        sorted_results = sorted_results[: num_initial_prompts // 2]
+        _, top_prompts, prompt_sources = zip(*sorted_results)
+        initial_population = list(top_prompts)
+        prompt_sources = list(prompt_sources)
 
         # fill up the rest with paraphrases of the top prompts
         promptindex_to_paraphrase = 0
diff --git a/evoprompt/task/base_prompts_mixin.py b/evoprompt/task/base_prompts.py
similarity index 53%
rename from evoprompt/task/base_prompts_mixin.py
rename to evoprompt/task/base_prompts.py
index c4d00cd..2eb9ad7 100644
--- a/evoprompt/task/base_prompts_mixin.py
+++ b/evoprompt/task/base_prompts.py
@@ -10,7 +10,7 @@ from evoprompt.utils import get_rng
 
 
 class BasePromptsFromJsonMixin:
     @staticmethod
-    def _load_json_file(path: str):
+    def _load_json_file(path: str) -> list[str]:
         with Path(path).open() as json_file:
             return json.load(json_file)
 
@@ -20,41 +20,48 @@
             raise Exception(
                 f"Class {self.__class__} does not exhibit attribute `base_prompts_files` which is needed for `BasePromptsFromJsonMixin`."
             )
-        base_prompts = []
+        prompts, sources = super().base_prompts
+        prompts_from_files = []
         for prompt_file in self.base_prompts_files:
-            base_prompts += self._load_json_file(prompt_file)
-        return base_prompts
+            prompts_from_files += self._load_json_file(prompt_file)
+        prompts += prompts_from_files
+        sources += ["baseprompt_file"] * len(prompts_from_files)
+        return prompts, sources
 
 
-class BasePromptsFromGeneration:
+class BasePromptsFromGenerationMixin:
     def __init__(self, *args, **kwargs) -> None:
         self.evolution_model: LLMModel = kwargs.get("evolution_model")
         super().__init__(*args, **kwargs)
 
     # this implements the initial population generation from Zhou et al., 2023: Large Language Models are Human-Level Prompt Engineers
+    # patience allows to stop the generation process if no new prompts can be generated
+    # can be set to -1 to generate as many prompts as needed (but can possibly run forever)
     def generate_prompt(
         self, num_prompts: int, patience: int = 10, allow_duplicates: bool = False
-    ) -> str:
+    ) -> list[str]:
         self.validation_dataset: Dataset
         samples = self.validation_dataset.shuffle(42).select(
             get_rng().choice(len(self.validation_dataset), 5, replace=False)
         )
 
-        prompt = "I gave a friend an instruction and five inputs. The friend read the instruction and wrote an output for every one of the inputs. Here are the input-output pairs:\n"
-        raise NotImplementedError(
-            "The prompt needs to be adapted for the model taking into account the correct format."
+        prompt = "I gave a friend a single instruction and five inputs. The friend read the instruction and wrote an output for every one of the inputs. Here are the input-output pairs:\n\n"
+        prompt += "\n".join(
+            f"Input:\n{self._get_prompt_text_for_datum(sample)}\nOutput:\n{self._get_gold_label_generation_for_datum(sample)}\n"
+            for sample in samples
         )
-        prompt = self.build_demonstration_prompt(samples, prompt=prompt)
         prompt += "\nThe instruction was "
         system_message = "You are a helpful assistant. Please provide the instruction wrapped within tags <instruction> and </instruction> that belongs to the given input-output pairs."
 
-        input(prompt)
+        messages = [
+            self.evolution_model._get_user_message(prompt)
+        ]  # , self.evolution_model._get_assistant_message("The instruction was ")]
 
         generated_prompts = []
         while len(generated_prompts) < num_prompts:
             response, _, _, _ = self.evolution_model.create_completion(
                 system_message=system_message,
-                prompt=prompt,
+                messages=messages,
+                use_randomness=True,
             )
-            input(response)
             matches = re.findall(
@@ -62,9 +69,9 @@
                 # regex that extracts anything within tags <instruction> and optional </instruction>
                 rf"<instruction>(.+?)(?:(?=</instruction>)|$)",
                 response,
                 flags=re.IGNORECASE,
             )
             if matches:
-                prompt = matches[-1].strip()
-                if allow_duplicates or prompt not in generated_prompts:
-                    generated_prompts.append(matches[-1].strip())
+                generated_prompt = matches[-1].strip()
+                if allow_duplicates or generated_prompt not in generated_prompts:
+                    generated_prompts.append(generated_prompt)
             else:
                 if patience == 0:
                     break
@@ -74,6 +81,15 @@
 
     @property
     def base_prompts(self):
-        num_prompts = getattr(self, "num_generated_base_prompts", 0)
+        if not hasattr(self, "num_generated_base_prompts"):
+            raise AttributeError(
+                f"{self.__class__} must expose attribute `num_generated_base_prompts`"
+            )
+        prompts, sources = super().base_prompts
+
+        num_prompts = self.num_generated_base_prompts
+        generated_prompts = self.generate_prompt(num_prompts)
+        prompts += generated_prompts
+        sources += ["baseprompt_gen"] * len(generated_prompts)
 
-        return self.generate_prompt(num_prompts)
+        return prompts, sources
diff --git a/evoprompt/task/question_answering.py b/evoprompt/task/question_answering.py
index 6266c86..fddf81a 100644
--- a/evoprompt/task/question_answering.py
+++ b/evoprompt/task/question_answering.py
@@ -8,7 +8,7 @@ from datasets import Dataset
 from evaluate import load as load_metric
 from llama_cpp import LlamaGrammar
 
-from evoprompt.task.base_prompts_mixin import BasePromptsFromGeneration
+from evoprompt.task.base_prompts import BasePromptsFromGenerationMixin
 from evoprompt.opt_types import ModelUsage
 from evoprompt.task.task import DatasetDatum, Task
 from evoprompt.utils import get_rng
@@ -172,7 +172,7 @@
         return "f1"
 
 
-class SQuAD(BasePromptsFromGeneration, QuestionAnswering):
+class SQuAD(BasePromptsFromGenerationMixin, QuestionAnswering):
     shorthand = "squad"
     num_generated_base_prompts = 10
 
@@ -202,7 +202,7 @@
 
     @property
     def base_prompts(self):
-        generated_base_prompts = super().base_prompts
-        return [
-            "In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."
-        ] + generated_base_prompts
+        prompts, sources = super().base_prompts
+        prompts.append("In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context.")
+        sources.append("baseprompt")
+        return prompts, sources
diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py
index bf72c17..89b092a 100644
--- a/evoprompt/task/sentiment_analysis.py
+++ b/evoprompt/task/sentiment_analysis.py
@@ -4,7 +4,7 @@
 from typing import Mapping
 
 from datasets import load_dataset
 
-from evoprompt.task.base_prompts_mixin import BasePromptsFromJsonMixin
+from evoprompt.task.base_prompts import BasePromptsFromJsonMixin
 from evoprompt.task import TextClassification
 from evoprompt.task.task import DatasetDatum
diff --git a/evoprompt/task/simplification.py b/evoprompt/task/simplification.py
index 37e0ec5..3a2e6aa 100644
--- a/evoprompt/task/simplification.py
+++ b/evoprompt/task/simplification.py
@@ -2,7 +2,7 @@
 import logging
 
 from evaluate import load as load_metric
 
-from evoprompt.task.base_prompts_mixin import BasePromptsFromJsonMixin
+from evoprompt.task.base_prompts import BasePromptsFromJsonMixin
 from evoprompt.task import TextGeneration
 from evoprompt.task.task import DatasetDatum
diff --git a/evoprompt/task/subjectivity_classification.py b/evoprompt/task/subjectivity_classification.py
index 7c3882e..10fea8b 100644
--- a/evoprompt/task/subjectivity_classification.py
+++ b/evoprompt/task/subjectivity_classification.py
@@ -3,7 +3,7 @@
 from typing import Mapping
 
 from datasets import load_dataset
 
-from evoprompt.task.base_prompts_mixin import BasePromptsFromJsonMixin
+from evoprompt.task.base_prompts import BasePromptsFromJsonMixin
 from evoprompt.task import TextClassification
 from evoprompt.task.task import DatasetDatum
diff --git a/evoprompt/task/summarization.py b/evoprompt/task/summarization.py
index fed21c1..4084d27 100644
--- a/evoprompt/task/summarization.py
+++ b/evoprompt/task/summarization.py
@@ -2,7 +2,7 @@
 import logging
 
 from evaluate import load as load_metric
 
-from evoprompt.task.base_prompts_mixin import BasePromptsFromJsonMixin
+from evoprompt.task.base_prompts import BasePromptsFromJsonMixin
 from evoprompt.task import TextGeneration
 from evoprompt.task.task import DatasetDatum
diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index 3b59db3..90f128b 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -11,7 +11,7 @@ from llama_cpp import LlamaGrammar
 from tqdm import tqdm
 
 from evoprompt.models import ChatMessage, ChatMessages, LLMModel
-from evoprompt.opt_types import ModelUsage
+from evoprompt.opt_types import ModelUsage, PromptSource
 from evoprompt.utils import log_calls
 
 logger = logging.getLogger(__name__)
@@ -492,4 +492,5 @@
 
     @property
     @abstractmethod
-    def base_prompts(self) -> list[str]: ...
+    def base_prompts(self) -> tuple[list[str], list[PromptSource]]:
+        return [], []
diff --git a/evoprompt/task/topic_classification.py b/evoprompt/task/topic_classification.py
index dd1905f..6230f85 100644
--- a/evoprompt/task/topic_classification.py
+++ b/evoprompt/task/topic_classification.py
@@ -3,7 +3,7 @@
 from typing import Mapping
 
 from datasets import load_dataset
 
-from evoprompt.task.base_prompts_mixin import BasePromptsFromJsonMixin
+from evoprompt.task.base_prompts import BasePromptsFromJsonMixin
 from evoprompt.task import TextClassification
 from evoprompt.task.task import DatasetDatum
-- 
GitLab
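
Note on the pattern this patch introduces: every `base_prompts` implementation
now returns a (prompts, sources) tuple and extends whatever
`super().base_prompts` produced, so prompts accumulate along the MRO of a
cooperative mixin chain instead of each class replacing the result of its
base. A minimal, self-contained sketch of that contract (the class and
attribute names below are invented for illustration and are not part of the
patch; only the tuple-accumulation pattern mirrors it):

    class TaskBase:
        @property
        def base_prompts(self) -> tuple[list[str], list[str]]:
            # Root of the cooperative chain: fresh lists for mixins to extend.
            return [], []


    class FromFilesMixin:
        # Stand-in for prompts loaded from JSON files in the real code.
        prompts_from_files = ["prompt loaded from a file"]

        @property
        def base_prompts(self) -> tuple[list[str], list[str]]:
            prompts, sources = super().base_prompts  # whatever the chain built so far
            prompts += self.prompts_from_files
            sources += ["baseprompt_file"] * len(self.prompts_from_files)
            return prompts, sources


    class FromGenerationMixin:
        num_generated_base_prompts = 2

        @property
        def base_prompts(self) -> tuple[list[str], list[str]]:
            prompts, sources = super().base_prompts
            generated = [f"generated prompt {i}" for i in range(self.num_generated_base_prompts)]
            prompts += generated
            sources += ["baseprompt_gen"] * len(generated)
            return prompts, sources


    class MyTask(FromGenerationMixin, FromFilesMixin, TaskBase):
        pass


    # MRO is FromGenerationMixin -> FromFilesMixin -> TaskBase, so the file
    # prompts are collected first and the generated ones are appended after.
    print(MyTask().base_prompts)

Keeping the sources list parallel to the prompts list is what lets
get_initial_prompts later record, for every member of the initial population,
whether it came from a file, from generation, or from paraphrasing.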
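The reworked generate_prompt also bounds its retry loop with a patience
counter: it gives up only after repeated responses that contain no parsable
instruction, and patience=-1 removes the limit entirely (the loop then runs
until enough prompts have been collected, possibly forever). A generic,
runnable sketch of that loop shape, with a stubbed producer standing in for
the create_completion call and tag extraction; note that unlike the patch,
this variant also counts duplicates against patience:

    import random


    def generate_unique(num_wanted: int, patience: int = 10) -> list[str]:
        """Collect up to num_wanted unique strings from an unreliable producer."""

        def produce() -> str | None:
            # Stand-in for one LLM call plus <instruction>-tag extraction;
            # None models a response without a usable instruction.
            return random.choice([None, "a", "b", "c", "d", "e"])

        results: list[str] = []
        while len(results) < num_wanted:
            candidate = produce()
            if candidate is not None and candidate not in results:
                results.append(candidate)
            else:
                if patience == 0:
                    break  # out of patience: stop with fewer results
                patience -= 1
        return results


    print(generate_unique(3))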