Skip to content
Snippets Groups Projects
Commit 662a11e8 authored by Max Kimmich's avatar Max Kimmich
Browse files

Simpllify BasePromptsFromJsonMixin and add auto prompts

parent da2cc6ed
No related branches found
No related tags found
No related merge requests found
......@@ -10,20 +10,14 @@ class BasePromptsFromJsonMixin:
@property
def base_prompts(self):
try:
base_prompts_files = getattr(self, "base_prompts_files")
base_prompts = []
for prompt_file in base_prompts_files:
base_prompts += self._load_json_file(prompt_file)
return base_prompts
except AttributeError:
try:
base_prompts_file = getattr(self, "base_prompts_file")
return self._load_json_file(base_prompts_file)
except AttributeError:
raise Exception(
f'Class {self.__class__} does not exhibit attribute "base_prompts_files" or "base_prompts_file" which is needed for `BasePromptsFromJsonMixin`.'
)
if not hasattr(self, "base_prompts_files"):
raise Exception(
f"Class {self.__class__} does not exhibit attribute `base_prompts_files` which is needed for `BasePromptsFromJsonMixin`."
)
base_prompts = []
for prompt_file in self.base_prompts_files:
base_prompts += self._load_json_file(prompt_file)
return base_prompts
class BasePromptsFromGeneration:
......
......@@ -29,7 +29,10 @@ class SentimentAnalysis(TextClassification):
class HfSST2(BasePromptsFromJsonMixin, SentimentAnalysis):
shorthand = "sst2-hf"
base_prompts_file = "evoprompt/initial_prompts/sst-2/prompts.json"
base_prompts_files = [
"evoprompt/initial_prompts/sst-2/prompts.json",
"evoprompt/initial_prompts/sst-2/prompts_auto.json",
]
def load_validation_set(
self, validation_dataset: str | None, validation_split: str | None
......@@ -48,7 +51,10 @@ class HfSST2(BasePromptsFromJsonMixin, SentimentAnalysis):
class SST2(BasePromptsFromJsonMixin, SentimentAnalysis):
shorthand = "sst2"
base_prompts_file = "evoprompt/initial_prompts/sst-2/prompts.json"
base_prompts_files = [
"evoprompt/initial_prompts/sst-2/prompts.json",
"evoprompt/initial_prompts/sst-2/prompts_auto.json",
]
def load_validation_set(
self, validation_dataset: str | None, validation_split: str | None
......@@ -85,7 +91,10 @@ class SST2(BasePromptsFromJsonMixin, SentimentAnalysis):
class SST5(BasePromptsFromJsonMixin, SentimentAnalysis):
shorthand = "sst5"
base_prompts_file = "evoprompt/initial_prompts/sst-5/prompts.json"
base_prompts_files = [
"evoprompt/initial_prompts/sst-5/prompts.json",
"evoprompt/initial_prompts/sst-5/prompts_auto.json",
]
def load_validation_set(
self, validation_dataset: str | None, validation_split: str | None
......@@ -120,7 +129,10 @@ class SST5(BasePromptsFromJsonMixin, SentimentAnalysis):
class HfMovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis):
shorthand = "mr-hf"
base_prompts_file = "evoprompt/initial_prompts/mr/prompts.json"
base_prompts_files = [
"evoprompt/initial_prompts/mr/prompts.json",
"evoprompt/initial_prompts/mr/prompts_auto.json",
]
def load_validation_set(
self, validation_dataset: str | None, validation_split: str | None
......@@ -143,7 +155,10 @@ class HfMovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis):
class MovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis):
shorthand = "mr"
base_prompts_file = "evoprompt/initial_prompts/mr/prompts.json"
base_prompts_files = [
"evoprompt/initial_prompts/mr/prompts.json",
"evoprompt/initial_prompts/mr/prompts_auto.json",
]
def load_validation_set(
self, validation_dataset: str | None, validation_split: str | None
......@@ -172,7 +187,10 @@ class MovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis):
class CustomerReviews(BasePromptsFromJsonMixin, SentimentAnalysis):
shorthand = "cr"
base_prompts_file = "evoprompt/initial_prompts/cr/prompts.json"
base_prompts_files = [
"evoprompt/initial_prompts/cr/prompts.json",
"evoprompt/initial_prompts/cr/prompts_auto.json",
]
def load_validation_set(
self, validation_dataset: str | None, validation_split: str | None
......
......@@ -34,7 +34,10 @@ class Simplification(TextGeneration):
class ASSET(BasePromptsFromJsonMixin, Simplification):
shorthand = "asset"
base_prompts_file = "evoprompt/initial_prompts/asset/prompts.json"
base_prompts_files = [
"evoprompt/initial_prompts/asset/prompts.json",
"evoprompt/initial_prompts/asset/prompts_auto.json",
]
def __init__(self, *args, **kwargs) -> None:
super().__init__(
......
......@@ -10,7 +10,10 @@ from evoprompt.task.task import DatasetDatum
class Subj(BasePromptsFromJsonMixin, TextClassification):
shorthand = "subj"
base_prompts_file = "evoprompt/initial_prompts/subj/prompts.json"
base_prompts_files = [
"evoprompt/initial_prompts/subj/prompts.json",
"evoprompt/initial_prompts/subj/prompts_auto.json",
]
def load_validation_set(
self, validation_dataset: str | None, validation_split: str | None
......
......@@ -35,7 +35,10 @@ class Summarization(TextGeneration):
class SAMSum(BasePromptsFromJsonMixin, Summarization):
shorthand = "sams"
base_prompts_file = "evoprompt/initial_prompts/sam/prompts.json"
base_prompts_files = [
"evoprompt/initial_prompts/sam/prompts.json",
"evoprompt/initial_prompts/sam/prompts_auto.json",
]
def __init__(self, *args, **kwargs) -> None:
super().__init__(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment