From d045eee1dab05df81dfcc94addafb132d2b9159e Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Fri, 9 Aug 2024 12:10:08 +0200
Subject: [PATCH] Refactor base prompt loading into BasePromptsFromJsonMixin

---
 evoprompt/helpers/prompts.py                  | 30 +++++
 evoprompt/optimization.py                     |  2 +-
 evoprompt/task/question_answering.py          |  4 +-
 evoprompt/task/sentiment_analysis.py          | 50 +++++--------
 evoprompt/task/simplification.py              | 14 ++----
 evoprompt/task/subjectivity_classification.py | 17 ++-----
 evoprompt/task/summarization.py               | 14 ++----
 evoprompt/task/task.py                        | 18 +++----
 evoprompt/task/topic_classification.py        | 27 ++--------
 9 files changed, 69 insertions(+), 107 deletions(-)
 create mode 100644 evoprompt/helpers/prompts.py

diff --git a/evoprompt/helpers/prompts.py b/evoprompt/helpers/prompts.py
new file mode 100644
index 0000000..414f9d3
--- /dev/null
+++ b/evoprompt/helpers/prompts.py
@@ -0,0 +1,30 @@
+import json
+from pathlib import Path
+
+
+class BasePromptsFromJsonMixin:
+    @staticmethod
+    def _load_json_file(path: str):
+        with Path(path).open() as json_file:
+            return json.load(json_file)
+
+    @property
+    def base_prompts(self):
+        try:
+            base_prompts_files = getattr(self, "base_prompts_files")
+            base_prompts = []
+            for prompt_file in base_prompts_files:
+                base_prompts += self._load_json_file(prompt_file)
+            return base_prompts
+        except AttributeError:
+            try:
+                base_prompts_file = getattr(self, "base_prompts_file")
+                return self._load_json_file(base_prompts_file)
+            except AttributeError:
+                raise Exception(
+                    f'Class {self.__class__} does not define the attribute "base_prompts_files" or "base_prompts_file", one of which is required by `BasePromptsFromJsonMixin`.'
+                )
+
+
+class BasePromptsFromGeneration:
+    pass
diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index 6aa71e7..1a8ef79 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -48,7 +48,7 @@ def paraphrase_prompts(
             or paraphrase not in paraphrases
             or max_tries - num_tries == n - len(paraphrases)
         ):
-            # add paraphrase only if not already present if unique_paraphrases==True
+            # add the paraphrase if it is new, if duplicates are allowed (unique_paraphrases==False), or if the remaining tries are only just enough to reach the requested n
            paraphrases.append(paraphrase)
 
     assert len(paraphrases) == n, "Requested %d paraphrases, but %d were generated." % (
diff --git a/evoprompt/task/question_answering.py b/evoprompt/task/question_answering.py
index 55be903..352dd51 100644
--- a/evoprompt/task/question_answering.py
+++ b/evoprompt/task/question_answering.py
@@ -139,8 +139,8 @@ class QuestionAnswering(Task):
 
     @property
     def base_prompts(self):
-        # TODO find good prompt
-        return """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
+        # TODO find good base prompts
+        return ["""In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. 
Make sure that the answer is taken directly from the context."""] class SQuAD(QuestionAnswering): diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py index 0e3e827..d3697c2 100644 --- a/evoprompt/task/sentiment_analysis.py +++ b/evoprompt/task/sentiment_analysis.py @@ -9,6 +9,7 @@ from typing import Mapping from datasets import load_dataset from llama_cpp import LlamaGrammar +from evoprompt.helpers.prompts import BasePromptsFromJsonMixin from evoprompt.task import Task, TextClassification from evoprompt.task.task import DatasetDatum @@ -25,16 +26,10 @@ class SentimentAnalysis(TextClassification): def _get_label_mapping(self) -> Mapping: return {"negative": 0, "positive": 1} - @property - def base_prompts(self): - # from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning - return [ - """In this task, you are given sentences from movie reviews. The task is to classify a sentence as "positive" if the sentiment of the sentence is positive or as "negative" if the sentiment of the sentence is negative.""" - ] - -class HfSST2(SentimentAnalysis): +class HfSST2(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "hf-sst2" + base_prompts_file = "evoprompt/initial_prompts/sst-2/prompts.json" def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -51,8 +46,9 @@ class HfSST2(SentimentAnalysis): return datum["label"] -class SST2(SentimentAnalysis): +class SST2(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "sst2" + base_prompts_file = "evoprompt/initial_prompts/sst-2/prompts.json" def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -86,15 +82,10 @@ class SST2(SentimentAnalysis): def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: return datum["label"] - @property - def base_prompts(self): - initial_prompts_file = Path("evoprompt/initial_prompts/sst-2/prompts.json") - with initial_prompts_file.open() as json_file: - return json.load(json_file) - -class SST5(SentimentAnalysis): +class SST5(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "sst5" + base_prompts_file = "evoprompt/initial_prompts/sst-5/prompts.json" def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -125,15 +116,10 @@ class SST5(SentimentAnalysis): classes = ["terrible", "bad", "okay", "good", "great"] return dict(zip(classes, range(len(classes)))) - @property - def base_prompts(self): - initial_prompts_file = Path("evoprompt/initial_prompts/sst-5/prompts.json") - with initial_prompts_file.open() as json_file: - return json.load(json_file) - -class HfMovieReviews(SentimentAnalysis): +class HfMovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "hf-mr" + base_prompts_file = "evoprompt/initial_prompts/mr/prompts.json" def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -154,8 +140,9 @@ class HfMovieReviews(SentimentAnalysis): return datum["label"] -class MovieReviews(SentimentAnalysis): +class MovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "mr" + base_prompts_file = "evoprompt/initial_prompts/mr/prompts.json" def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -181,15 +168,10 @@ class MovieReviews(SentimentAnalysis): def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: return datum["label"] - @property - def base_prompts(self): - initial_prompts_file = 
Path("evoprompt/initial_prompts/mr/prompts.json") - with initial_prompts_file.open() as json_file: - return json.load(json_file) - -class CustomerReviews(SentimentAnalysis): +class CustomerReviews(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "cr" + base_prompts_file = "evoprompt/initial_prompts/cr/prompts.json" def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -214,9 +196,3 @@ class CustomerReviews(SentimentAnalysis): def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: return datum["label"] - - @property - def base_prompts(self): - initial_prompts_file = Path("evoprompt/initial_prompts/cr/prompts.json") - with initial_prompts_file.open() as json_file: - return json.load(json_file) diff --git a/evoprompt/task/simplification.py b/evoprompt/task/simplification.py index 06836aa..2cec372 100644 --- a/evoprompt/task/simplification.py +++ b/evoprompt/task/simplification.py @@ -8,6 +8,7 @@ from typing import Mapping from evaluate import load as load_metric from llama_cpp import LlamaGrammar +from evoprompt.helpers.prompts import BasePromptsFromJsonMixin from evoprompt.models import LLMModel from evoprompt.task import Task from evoprompt.task.task import DatasetDatum @@ -51,13 +52,10 @@ class Simplification(Task): def metric_name(self): return "sari" - @property - def base_prompts(self): - return ["Given the English sentence, the simplification of the sentence is"] - -class ASSET(Simplification): +class ASSET(BasePromptsFromJsonMixin, Simplification): shorthand = "asset" + base_prompts_file = "evoprompt/initial_prompts/asset/prompts.json" def __init__(self, *args, **kwargs) -> None: super().__init__( @@ -74,9 +72,3 @@ class ASSET(Simplification): def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: return datum["simplifications"] - - @property - def base_prompts(self): - initial_prompts_file = Path("evoprompt/initial_prompts/asset/prompts.json") - with initial_prompts_file.open() as json_file: - return json.load(json_file) diff --git a/evoprompt/task/subjectivity_classification.py b/evoprompt/task/subjectivity_classification.py index 144b00d..63988c6 100644 --- a/evoprompt/task/subjectivity_classification.py +++ b/evoprompt/task/subjectivity_classification.py @@ -5,12 +5,14 @@ from typing import Mapping from datasets import load_dataset +from evoprompt.helpers.prompts import BasePromptsFromJsonMixin from evoprompt.task import TextClassification from evoprompt.task.task import DatasetDatum -class Subj(TextClassification): +class Subj(BasePromptsFromJsonMixin, TextClassification): shorthand = "subj" + base_prompts_file = "evoprompt/initial_prompts/subj/prompts.json" def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -41,16 +43,3 @@ class Subj(TextClassification): def _get_label_mapping(self) -> Mapping: classes = ["subjective", "objective"] return dict(zip(classes, range(len(classes)))) - - @property - def base_prompts(self): - # from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning - return [ - """In this task, you are given sentences from reviews. 
The task is to classify a sentence as "subjective" if the opinion of the sentence is subjective or as "objective" if the opinion of the sentence is objective.""" - ] - - @property - def base_prompts(self): - initial_prompts_file = Path("evoprompt/initial_prompts/subj/prompts.json") - with initial_prompts_file.open() as json_file: - return json.load(json_file) diff --git a/evoprompt/task/summarization.py b/evoprompt/task/summarization.py index 4d5cee0..80664d4 100644 --- a/evoprompt/task/summarization.py +++ b/evoprompt/task/summarization.py @@ -8,6 +8,7 @@ from typing import Mapping from evaluate import load as load_metric from llama_cpp import LlamaGrammar +from evoprompt.helpers.prompts import BasePromptsFromJsonMixin from evoprompt.models import LLMModel from evoprompt.task import Task from evoprompt.task.task import DatasetDatum @@ -52,13 +53,10 @@ class Summarization(Task): def metric_name(self): return "rougeL" - @property - def base_prompts(self): - return ["Please summarize the main context."] - -class SAMSum(Summarization): +class SAMSum(BasePromptsFromJsonMixin, Summarization): shorthand = "sams" + base_prompts_file = "evoprompt/initial_prompts/sam/prompts.json" def __init__(self, *args, **kwargs) -> None: super().__init__( @@ -75,9 +73,3 @@ class SAMSum(Summarization): def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: return datum["summary"] - - @property - def base_prompts(self): - initial_prompts_file = Path("evoprompt/initial_prompts/sam/prompts.json") - with initial_prompts_file.open() as json_file: - return json.load(json_file) diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py index ff125b6..58d8c77 100644 --- a/evoprompt/task/task.py +++ b/evoprompt/task/task.py @@ -273,7 +273,16 @@ class Task(metaclass=ABCMeta): self.model = model self.debug = debug self.use_grammar = use_grammar + self.evaluation_strategy = get_evaluation_strategy(evaluation_strategy)(self) + logger.info( + f"using evaluation strategy: {self.evaluation_strategy.__class__.__name__}" + ) + if hasattr(self.evaluation_strategy, "early_stopping"): + logger.info( + f"using early stopping: {self.evaluation_strategy.early_stopping}", + ) + self.use_evolution_demo = use_evolution_demo self.validation_dataset = self.load_validation_set( @@ -361,15 +370,6 @@ class Task(metaclass=ABCMeta): dataset_iterator = self.evaluation_strategy.get_dataset_iterator( dataset, parent_histories ) - - logger.info( - f"using evaluation strategy: {self.evaluation_strategy.__class__.__name__}" - ) - if hasattr(self.evaluation_strategy, "early_stopping"): - logger.info( - f"using early stopping: {self.evaluation_strategy.early_stopping}", - ) - results: list = [] evaluation_usage = ModelUsage() evaluation_history = [] diff --git a/evoprompt/task/topic_classification.py b/evoprompt/task/topic_classification.py index 261858b..aecec3e 100644 --- a/evoprompt/task/topic_classification.py +++ b/evoprompt/task/topic_classification.py @@ -5,13 +5,15 @@ from typing import Mapping from datasets import load_dataset +from evoprompt.helpers.prompts import BasePromptsFromJsonMixin from evoprompt.models import LLMModel from evoprompt.task import TextClassification from evoprompt.task.task import DatasetDatum -class AGNews(TextClassification): +class AGNews(BasePromptsFromJsonMixin, TextClassification): shorthand = "agn" + base_prompts_file = "evoprompt/initial_prompts/agnews/prompts.json" def __init__(self, *args, **kwargs) -> None: super().__init__( @@ -34,16 +36,10 @@ class AGNews(TextClassification): classes = ["world", 
"sports", "business", "tech"] return dict(zip(classes, range(len(classes)))) - @property - def base_prompts(self): - # from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning - return [ - """In this task, you are given a news article. Your task is to classify the article to one out of the four topics "World", "Sports", "Business", "Tech" if the article"s main topic is relevant to the world, sports, business, and technology, correspondingly. If you are not sure about the topic, choose the closest option.""" - ] - -class TREC(TextClassification): +class TREC(BasePromptsFromJsonMixin, TextClassification): shorthand = "trec" + base_prompts_file = "evoprompt/initial_prompts/trec/prompts.json" def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -73,16 +69,3 @@ class TREC(TextClassification): def _get_label_mapping(self) -> Mapping: classes = ["description", "entity", "expression", "human", "location", "number"] return dict(zip(classes, range(len(classes)))) - - @property - def base_prompts(self): - # from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning - return [ - """You are given a question. You need to detect which category better describes the question. Answer with "Description", "Entity", "Expression", "Human", "Location", and "Number".""" - ] - - @property - def base_prompts(self): - initial_prompts_file = Path("evoprompt/initial_prompts/trec/prompts.json") - with initial_prompts_file.open() as json_file: - return json.load(json_file) -- GitLab