diff --git a/evoprompt/helpers/prompts.py b/evoprompt/helpers/prompts.py index 414f9d3659912aced77af34879eeab4ad049bb44..df1c725e0d275943916a22d7b7411a37b6e7b839 100644 --- a/evoprompt/helpers/prompts.py +++ b/evoprompt/helpers/prompts.py @@ -10,20 +10,14 @@ class BasePromptsFromJsonMixin: @property def base_prompts(self): - try: - base_prompts_files = getattr(self, "base_prompts_files") - base_prompts = [] - for prompt_file in base_prompts_files: - base_prompts += self._load_json_file(prompt_file) - return base_prompts - except AttributeError: - try: - base_prompts_file = getattr(self, "base_prompts_file") - return self._load_json_file(base_prompts_file) - except AttributeError: - raise Exception( - f'Class {self.__class__} does not exhibit attribute "base_prompts_files" or "base_prompts_file" which is needed for `BasePromptsFromJsonMixin`.' - ) + if not hasattr(self, "base_prompts_files"): + raise Exception( + f"Class {self.__class__} does not exhibit attribute `base_prompts_files` which is needed for `BasePromptsFromJsonMixin`." + ) + base_prompts = [] + for prompt_file in self.base_prompts_files: + base_prompts += self._load_json_file(prompt_file) + return base_prompts class BasePromptsFromGeneration: diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py index e0126a0055837b3d4905839358d031fd6bf3814b..012ee0bd9354ba506a6e6dbb01519a1d404ee7f1 100644 --- a/evoprompt/task/sentiment_analysis.py +++ b/evoprompt/task/sentiment_analysis.py @@ -29,7 +29,10 @@ class SentimentAnalysis(TextClassification): class HfSST2(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "sst2-hf" - base_prompts_file = "evoprompt/initial_prompts/sst-2/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/sst-2/prompts.json", + "evoprompt/initial_prompts/sst-2/prompts_auto.json", + ] def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -48,7 +51,10 @@ class HfSST2(BasePromptsFromJsonMixin, SentimentAnalysis): class SST2(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "sst2" - base_prompts_file = "evoprompt/initial_prompts/sst-2/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/sst-2/prompts.json", + "evoprompt/initial_prompts/sst-2/prompts_auto.json", + ] def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -85,7 +91,10 @@ class SST2(BasePromptsFromJsonMixin, SentimentAnalysis): class SST5(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "sst5" - base_prompts_file = "evoprompt/initial_prompts/sst-5/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/sst-5/prompts.json", + "evoprompt/initial_prompts/sst-5/prompts_auto.json", + ] def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -120,7 +129,10 @@ class SST5(BasePromptsFromJsonMixin, SentimentAnalysis): class HfMovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "mr-hf" - base_prompts_file = "evoprompt/initial_prompts/mr/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/mr/prompts.json", + "evoprompt/initial_prompts/mr/prompts_auto.json", + ] def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -143,7 +155,10 @@ class HfMovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis): class MovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "mr" - base_prompts_file = "evoprompt/initial_prompts/mr/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/mr/prompts.json", + "evoprompt/initial_prompts/mr/prompts_auto.json", + ] def load_validation_set( self, validation_dataset: str | None, validation_split: str | None @@ -172,7 +187,10 @@ class MovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis): class CustomerReviews(BasePromptsFromJsonMixin, SentimentAnalysis): shorthand = "cr" - base_prompts_file = "evoprompt/initial_prompts/cr/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/cr/prompts.json", + "evoprompt/initial_prompts/cr/prompts_auto.json", + ] def load_validation_set( self, validation_dataset: str | None, validation_split: str | None diff --git a/evoprompt/task/simplification.py b/evoprompt/task/simplification.py index 69231a80fe7464405f7382a1a82e4d9eb6a360fd..e528e05f41220edb3e5085e22b04e9803c165ed2 100644 --- a/evoprompt/task/simplification.py +++ b/evoprompt/task/simplification.py @@ -34,7 +34,10 @@ class Simplification(TextGeneration): class ASSET(BasePromptsFromJsonMixin, Simplification): shorthand = "asset" - base_prompts_file = "evoprompt/initial_prompts/asset/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/asset/prompts.json", + "evoprompt/initial_prompts/asset/prompts_auto.json", + ] def __init__(self, *args, **kwargs) -> None: super().__init__( diff --git a/evoprompt/task/subjectivity_classification.py b/evoprompt/task/subjectivity_classification.py index 5db5edccc5b8f2f2951628276c6a925d4d909411..aa5e57e00fbce3c7a620b3580092108714443fa1 100644 --- a/evoprompt/task/subjectivity_classification.py +++ b/evoprompt/task/subjectivity_classification.py @@ -10,7 +10,10 @@ from evoprompt.task.task import DatasetDatum class Subj(BasePromptsFromJsonMixin, TextClassification): shorthand = "subj" - base_prompts_file = "evoprompt/initial_prompts/subj/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/subj/prompts.json", + "evoprompt/initial_prompts/subj/prompts_auto.json", + ] def load_validation_set( self, validation_dataset: str | None, validation_split: str | None diff --git a/evoprompt/task/summarization.py b/evoprompt/task/summarization.py index 3d4789a5c84c27300e25115606c0149124f23d54..213131e255eefde09195519677d13805353daaa1 100644 --- a/evoprompt/task/summarization.py +++ b/evoprompt/task/summarization.py @@ -35,7 +35,10 @@ class Summarization(TextGeneration): class SAMSum(BasePromptsFromJsonMixin, Summarization): shorthand = "sams" - base_prompts_file = "evoprompt/initial_prompts/sam/prompts.json" + base_prompts_files = [ + "evoprompt/initial_prompts/sam/prompts.json", + "evoprompt/initial_prompts/sam/prompts_auto.json", + ] def __init__(self, *args, **kwargs) -> None: super().__init__(