Skip to content
Snippets Groups Projects
Commit 5ee5bb02 authored by Max Kimmich's avatar Max Kimmich
Browse files

Allow to cache demonstration samples via datasets library

parent 5e5d7541
No related branches found
No related tags found
No related merge requests found
......@@ -42,10 +42,12 @@ class HfSST2(BasePromptsFromJsonMixin, SentimentAnalysis):
def load_test_set(self, test_dataset: str, test_split: str | None):
return super().load_test_set("stanfordnlp/sst2", "test")
def _get_text_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_text_for_datum(datum: DatasetDatum) -> str:
return datum["sentence"]
def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
return datum["label"]
......@@ -82,10 +84,12 @@ class SST2(BasePromptsFromJsonMixin, SentimentAnalysis):
)
return dataset
def _get_text_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_text_for_datum(datum: DatasetDatum) -> str:
return datum["sentence"]
def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
return datum["label"]
......@@ -114,10 +118,12 @@ class SST5(BasePromptsFromJsonMixin, SentimentAnalysis):
column_names=["label", "sentence"],
)
def _get_text_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_text_for_datum(datum: DatasetDatum) -> str:
return datum["sentence"]
def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
return datum["label"]
@staticmethod
......@@ -146,10 +152,12 @@ class HfMovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis):
"cornell-movie-review-data/rotten_tomatoes", "test"
)
def _get_text_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_text_for_datum(datum: DatasetDatum) -> str:
return datum["text"]
def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
return datum["label"]
......@@ -178,10 +186,12 @@ class MovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis):
column_names=["label", "text"],
)
def _get_text_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_text_for_datum(datum: DatasetDatum) -> str:
return datum["text"]
def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
return datum["label"]
......@@ -210,8 +220,10 @@ class CustomerReviews(BasePromptsFromJsonMixin, SentimentAnalysis):
column_names=["label", "text"],
)
def _get_text_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_text_for_datum(datum: DatasetDatum) -> str:
return datum["text"]
def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
return datum["label"]
......@@ -49,8 +49,10 @@ class ASSET(BasePromptsFromJsonMixin, Simplification):
**kwargs
)
def _get_text_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_text_for_datum(datum: DatasetDatum) -> str:
return datum["original"]
def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
return datum["simplifications"]
......@@ -34,10 +34,12 @@ class Subj(BasePromptsFromJsonMixin, TextClassification):
column_names=["label", "text"],
)
def _get_text_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_text_for_datum(datum: DatasetDatum) -> str:
return datum["text"]
def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
return datum["label"]
@staticmethod
......
......@@ -50,8 +50,10 @@ class SAMSum(BasePromptsFromJsonMixin, Summarization):
**kwargs
)
def _get_text_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_text_for_datum(datum: DatasetDatum) -> str:
return datum["dialogue"]
def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
return datum["summary"]
......@@ -63,9 +63,17 @@ class TextClassification(Task):
lambda _, idx: {"idx": idx}, with_indices=True
).shuffle(42)
sample_ids = []
def check_sample_label(sample, label, gold_label_for_datum_fn):
return gold_label_for_datum_fn(sample) == label
for label in self._get_label_mapping().values():
sample_ids_for_label = dataset_with_row_indices.filter(
lambda sample: self._get_gold_label_for_datum(sample) == label
check_sample_label,
fn_kwargs={
"label": label,
"gold_label_for_datum_fn": self._get_gold_label_for_datum,
},
)[:n_evaluation_demo]["idx"]
sample_ids += sample_ids_for_label
return sample_ids
......
......@@ -28,10 +28,12 @@ class AGNews(BasePromptsFromJsonMixin, TopicClassification):
**kwargs
)
def _get_text_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_text_for_datum(datum: DatasetDatum) -> str:
return datum["text"]
def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
return datum["label"]
@staticmethod
......@@ -63,10 +65,12 @@ class TREC(BasePromptsFromJsonMixin, TopicClassification):
column_names=["label", "question"],
)
def _get_text_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_text_for_datum(datum: DatasetDatum) -> str:
return datum["question"]
def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
@staticmethod
def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
return datum["label"]
@lru_cache
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment