From 662a11e8f343323fb2cf7c2057d61e24d2bf349b Mon Sep 17 00:00:00 2001 From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de> Date: Thu, 22 Aug 2024 15:31:44 +0200 Subject: [PATCH] Simpllify BasePromptsFromJsonMixin and add auto prompts --- evoprompt/helpers/prompts.py | 22 +++++--------- evoprompt/task/sentiment_analysis.py | 30 +++++++++++++++---- evoprompt/task/simplification.py | 5 +++- evoprompt/task/subjectivity_classification.py | 5 +++- evoprompt/task/summarization.py | 5 +++- 5 files changed, 44 insertions(+), 23 deletions(-) diff --git a/evoprompt/helpers/prompts.py b/evoprompt/helpers/prompts.py index 414f9d3..df1c725 100644 --- a/evoprompt/helpers/prompts.py +++ b/evoprompt/helpers/prompts.py @@ -10,20 +10,14 @@ class BasePromptsFromJsonMixin: @property def base_prompts(self): - try: - base_prompts_files = getattr(self, "base_prompts_files") - base_prompts = [] - for prompt_file in base_prompts_files: - base_prompts += self._load_json_file(prompt_file) - return base_prompts - except AttributeError: - try: - base_prompts_file = getattr(self, "base_prompts_file") - return self._load_json_file(base_prompts_file) - except AttributeError: - raise Exception( - f'Class {self.__class__} does not exhibit attribute "base_prompts_files" or "base_prompts_file" which is needed for `BasePromptsFromJsonMixin`.' - ) + if not hasattr(self, "base_prompts_files"): + raise Exception( + f"Class {self.__class__} does not exhibit attribute `base_prompts_files` which is needed for `BasePromptsFromJsonMixin`." + ) + base_prompts = [] + for prompt_file in self.base_prompts_files: + base_prompts += self._load_json_file(prompt_file) + return base_prompts class BasePromptsFromGeneration: diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py index e0126a0..012ee0b 100644 --- a/evoprompt/task/sentiment_analysis.py +++ b/evoprompt/task/sentiment_analysis.py @@ -29,7 +29,10 @@ class SentimentAnalysis(TextClassification): class HfSST2(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "sst2-hf" - base_prompts_file = "evoprompt/initial_prompts/sst-2/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/sst-2/prompts.json", + "evoprompt/initial_prompts/sst-2/prompts_auto.json", + ] def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -48,7 +51,10 @@ class HfSST2(BasePromptsFromJsonMixin, SentimentAnalysis): class SST2(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "sst2" - base_prompts_file = "evoprompt/initial_prompts/sst-2/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/sst-2/prompts.json", + "evoprompt/initial_prompts/sst-2/prompts_auto.json", + ] def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -85,7 +91,10 @@ class SST2(BasePromptsFromJsonMixin, SentimentAnalysis): class SST5(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "sst5" - base_prompts_file = "evoprompt/initial_prompts/sst-5/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/sst-5/prompts.json", + "evoprompt/initial_prompts/sst-5/prompts_auto.json", + ] def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -120,7 +129,10 @@ class SST5(BasePromptsFromJsonMixin, SentimentAnalysis): class HfMovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "mr-hf" - base_prompts_file = "evoprompt/initial_prompts/mr/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/mr/prompts.json", + "evoprompt/initial_prompts/mr/prompts_auto.json", + ] def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -143,7 +155,10 @@ class HfMovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis): class MovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "mr" - base_prompts_file = "evoprompt/initial_prompts/mr/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/mr/prompts.json", + "evoprompt/initial_prompts/mr/prompts_auto.json", + ] def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -172,7 +187,10 @@ class MovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis): class CustomerReviews(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "cr" - base_prompts_file = "evoprompt/initial_prompts/cr/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/cr/prompts.json", + "evoprompt/initial_prompts/cr/prompts_auto.json", + ] def load_validation_set( self, validation_dataset: str | None, validation_split: str | None diff --git a/evoprompt/task/simplification.py b/evoprompt/task/simplification.py index 69231a8..e528e05 100644 --- a/evoprompt/task/simplification.py +++ b/evoprompt/task/simplification.py @@ -34,7 +34,10 @@ class Simplification(TextGeneration): class ASSET(BasePromptsFromJsonMixin, Simplification): shorthand = "asset" - base_prompts_file = "evoprompt/initial_prompts/asset/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/asset/prompts.json", + "evoprompt/initial_prompts/asset/prompts_auto.json", + ] def __init__(self, *args, **kwargs) -> None: super().__init__( diff --git a/evoprompt/task/subjectivity_classification.py b/evoprompt/task/subjectivity_classification.py index 5db5edc..aa5e57e 100644 --- a/evoprompt/task/subjectivity_classification.py +++ b/evoprompt/task/subjectivity_classification.py @@ -10,7 +10,10 @@ from evoprompt.task.task import DatasetDatum class Subj(BasePromptsFromJsonMixin, TextClassification): shorthand = "subj" - base_prompts_file = "evoprompt/initial_prompts/subj/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/subj/prompts.json", + "evoprompt/initial_prompts/subj/prompts_auto.json", + ] def load_validation_set( self, validation_dataset: str | None, validation_split: str | None diff --git a/evoprompt/task/summarization.py b/evoprompt/task/summarization.py index 3d4789a..213131e 100644 --- a/evoprompt/task/summarization.py +++ b/evoprompt/task/summarization.py @@ -35,7 +35,10 @@ class Summarization(TextGeneration): class SAMSum(BasePromptsFromJsonMixin, Summarization): shorthand = "sams" - base_prompts_file = "evoprompt/initial_prompts/sam/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/sam/prompts.json", + "evoprompt/initial_prompts/sam/prompts_auto.json", + ] def __init__(self, *args, **kwargs) -> None: super().__init__( -- GitLab