diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py index e5d76733a521a181e2356375c3dc51d05eaedbb7..0e3e827f88581e1bebae9e0bfe84f82bdd3db3d2 100644 --- a/evoprompt/task/sentiment_analysis.py +++ b/evoprompt/task/sentiment_analysis.py @@ -1,7 +1,9 @@ +import json import logging from abc import abstractmethod from argparse import Namespace from functools import lru_cache +from pathlib import Path from typing import Mapping from datasets import load_dataset @@ -84,6 +86,12 @@ class SST2(SentimentAnalysis): def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: return datum["label"] + @property + def base_prompts(self): + initial_prompts_file = Path("evoprompt/initial_prompts/sst-2/prompts.json") + with initial_prompts_file.open() as json_file: + return json.load(json_file) + class SST5(SentimentAnalysis): shorthand = "sst5" @@ -119,10 +127,9 @@ class SST5(SentimentAnalysis): @property def base_prompts(self): - # from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning - return [ - "In this task, you are given sentences from movie reviews. Based on the given review, classify it to one of the five classes: (1) terrible, (2) bad, (3) okay, (4) good, and (5) great." - ] + initial_prompts_file = Path("evoprompt/initial_prompts/sst-5/prompts.json") + with initial_prompts_file.open() as json_file: + return json.load(json_file) class HfMovieReviews(SentimentAnalysis): @@ -174,6 +181,12 @@ class MovieReviews(SentimentAnalysis): def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: return datum["label"] + @property + def base_prompts(self): + initial_prompts_file = Path("evoprompt/initial_prompts/mr/prompts.json") + with initial_prompts_file.open() as json_file: + return json.load(json_file) + class CustomerReviews(SentimentAnalysis): shorthand = "cr" @@ -201,3 +214,9 @@ class CustomerReviews(SentimentAnalysis): def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: return datum["label"] + + @property + def base_prompts(self): + initial_prompts_file = Path("evoprompt/initial_prompts/cr/prompts.json") + with initial_prompts_file.open() as json_file: + return json.load(json_file) diff --git a/evoprompt/task/simplification.py b/evoprompt/task/simplification.py index 0af6a411bd80eaa64f86ff53e8beebc8e54a6f89..06836aa4ab5b2b3a5c1dcecb572c867bc4730619 100644 --- a/evoprompt/task/simplification.py +++ b/evoprompt/task/simplification.py @@ -1,6 +1,8 @@ +import json import logging from abc import abstractmethod from functools import lru_cache +from pathlib import Path from typing import Mapping from evaluate import load as load_metric @@ -72,3 +74,9 @@ class ASSET(Simplification): def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: return datum["simplifications"] + + @property + def base_prompts(self): + initial_prompts_file = Path("evoprompt/initial_prompts/asset/prompts.json") + with initial_prompts_file.open() as json_file: + return json.load(json_file) diff --git a/evoprompt/task/subjectivity_classification.py b/evoprompt/task/subjectivity_classification.py index eabf7022a4b96e651114eb1d92b83688fccdba45..144b00d9813820b97637b3d3fa574627e339f25e 100644 --- a/evoprompt/task/subjectivity_classification.py +++ b/evoprompt/task/subjectivity_classification.py @@ -1,4 +1,6 @@ +import json from functools import lru_cache +from pathlib import Path from typing import Mapping from datasets import load_dataset @@ -46,3 +48,9 @@ class Subj(TextClassification): return [ """In this task, you are given sentences from reviews. The task is to classify a sentence as "subjective" if the opinion of the sentence is subjective or as "objective" if the opinion of the sentence is objective.""" ] + + @property + def base_prompts(self): + initial_prompts_file = Path("evoprompt/initial_prompts/subj/prompts.json") + with initial_prompts_file.open() as json_file: + return json.load(json_file) diff --git a/evoprompt/task/summarization.py b/evoprompt/task/summarization.py index cb4e0a784a5b4bd474ab4eb7c13dc680fba7fc43..4d5cee00fe56d99c5a07f7ba36590ab6929cc585 100644 --- a/evoprompt/task/summarization.py +++ b/evoprompt/task/summarization.py @@ -1,6 +1,8 @@ +import json import logging from abc import abstractmethod from functools import lru_cache +from pathlib import Path from typing import Mapping from evaluate import load as load_metric @@ -73,3 +75,9 @@ class SAMSum(Summarization): def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: return datum["summary"] + + @property + def base_prompts(self): + initial_prompts_file = Path("evoprompt/initial_prompts/sam/prompts.json") + with initial_prompts_file.open() as json_file: + return json.load(json_file) diff --git a/evoprompt/task/topic_classification.py b/evoprompt/task/topic_classification.py index 58ba5a742a28c8c81a269fa0c29e44daa2e391ef..261858b5fa483a7f201b74418637e7d2d6702638 100644 --- a/evoprompt/task/topic_classification.py +++ b/evoprompt/task/topic_classification.py @@ -1,4 +1,6 @@ +import json from functools import lru_cache +from pathlib import Path from typing import Mapping from datasets import load_dataset @@ -78,3 +80,9 @@ class TREC(TextClassification): return [ """You are given a question. You need to detect which category better describes the question. Answer with "Description", "Entity", "Expression", "Human", "Location", and "Number".""" ] + + @property + def base_prompts(self): + initial_prompts_file = Path("evoprompt/initial_prompts/trec/prompts.json") + with initial_prompts_file.open() as json_file: + return json.load(json_file)