From 5ee5bb02c31814d25154d132d659439e0143a02e Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Wed, 28 Aug 2024 12:54:55 +0200
Subject: [PATCH] Allow to cache demonstration samples via datasets library

---
 evoprompt/task/sentiment_analysis.py          | 36 ++++++++++++-------
 evoprompt/task/simplification.py              |  6 ++--
 evoprompt/task/subjectivity_classification.py |  6 ++--
 evoprompt/task/summarization.py               |  6 ++--
 evoprompt/task/text_classification.py         | 10 +++++-
 evoprompt/task/topic_classification.py        | 12 ++++---
 6 files changed, 53 insertions(+), 23 deletions(-)

diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py
index 012ee0b..82bbd89 100644
--- a/evoprompt/task/sentiment_analysis.py
+++ b/evoprompt/task/sentiment_analysis.py
@@ -42,10 +42,12 @@ class HfSST2(BasePromptsFromJsonMixin, SentimentAnalysis):
     def load_test_set(self, test_dataset: str, test_split: str | None):
         return super().load_test_set("stanfordnlp/sst2", "test")
 
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_text_for_datum(datum: DatasetDatum) -> str:
         return datum["sentence"]
 
-    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
         return datum["label"]
 
 
@@ -82,10 +84,12 @@ class SST2(BasePromptsFromJsonMixin, SentimentAnalysis):
         )
         return dataset
 
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_text_for_datum(datum: DatasetDatum) -> str:
         return datum["sentence"]
 
-    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
         return datum["label"]
 
 
@@ -114,10 +118,12 @@ class SST5(BasePromptsFromJsonMixin, SentimentAnalysis):
             column_names=["label", "sentence"],
         )
 
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_text_for_datum(datum: DatasetDatum) -> str:
         return datum["sentence"]
 
-    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
         return datum["label"]
 
     @staticmethod
@@ -146,10 +152,12 @@ class HfMovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis):
             "cornell-movie-review-data/rotten_tomatoes", "test"
         )
 
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_text_for_datum(datum: DatasetDatum) -> str:
         return datum["text"]
 
-    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
         return datum["label"]
 
 
@@ -178,10 +186,12 @@ class MovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis):
             column_names=["label", "text"],
         )
 
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_text_for_datum(datum: DatasetDatum) -> str:
         return datum["text"]
 
-    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
         return datum["label"]
 
 
@@ -210,8 +220,10 @@ class CustomerReviews(BasePromptsFromJsonMixin, SentimentAnalysis):
             column_names=["label", "text"],
         )
 
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_text_for_datum(datum: DatasetDatum) -> str:
         return datum["text"]
 
-    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
         return datum["label"]
diff --git a/evoprompt/task/simplification.py b/evoprompt/task/simplification.py
index e528e05..1d62bc4 100644
--- a/evoprompt/task/simplification.py
+++ b/evoprompt/task/simplification.py
@@ -49,8 +49,10 @@ class ASSET(BasePromptsFromJsonMixin, Simplification):
             **kwargs
         )
 
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_text_for_datum(datum: DatasetDatum) -> str:
         return datum["original"]
 
-    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
         return datum["simplifications"]
diff --git a/evoprompt/task/subjectivity_classification.py b/evoprompt/task/subjectivity_classification.py
index aa5e57e..051294d 100644
--- a/evoprompt/task/subjectivity_classification.py
+++ b/evoprompt/task/subjectivity_classification.py
@@ -34,10 +34,12 @@ class Subj(BasePromptsFromJsonMixin, TextClassification):
             column_names=["label", "text"],
         )
 
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_text_for_datum(datum: DatasetDatum) -> str:
         return datum["text"]
 
-    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
         return datum["label"]
 
     @staticmethod
diff --git a/evoprompt/task/summarization.py b/evoprompt/task/summarization.py
index 213131e..9d485b8 100644
--- a/evoprompt/task/summarization.py
+++ b/evoprompt/task/summarization.py
@@ -50,8 +50,10 @@ class SAMSum(BasePromptsFromJsonMixin, Summarization):
             **kwargs
         )
 
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_text_for_datum(datum: DatasetDatum) -> str:
         return datum["dialogue"]
 
-    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
         return datum["summary"]
diff --git a/evoprompt/task/text_classification.py b/evoprompt/task/text_classification.py
index 2c40725..74e8dbc 100644
--- a/evoprompt/task/text_classification.py
+++ b/evoprompt/task/text_classification.py
@@ -63,9 +63,17 @@ class TextClassification(Task):
             lambda _, idx: {"idx": idx}, with_indices=True
         ).shuffle(42)
         sample_ids = []
+
+        def check_sample_label(sample, label, gold_label_for_datum_fn):
+            return gold_label_for_datum_fn(sample) == label
+
         for label in self._get_label_mapping().values():
             sample_ids_for_label = dataset_with_row_indices.filter(
-                lambda sample: self._get_gold_label_for_datum(sample) == label
+                check_sample_label,
+                fn_kwargs={
+                    "label": label,
+                    "gold_label_for_datum_fn": self._get_gold_label_for_datum,
+                },
             )[:n_evaluation_demo]["idx"]
             sample_ids += sample_ids_for_label
         return sample_ids
diff --git a/evoprompt/task/topic_classification.py b/evoprompt/task/topic_classification.py
index e16aef3..6e9107b 100644
--- a/evoprompt/task/topic_classification.py
+++ b/evoprompt/task/topic_classification.py
@@ -28,10 +28,12 @@ class AGNews(BasePromptsFromJsonMixin, TopicClassification):
             **kwargs
         )
 
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_text_for_datum(datum: DatasetDatum) -> str:
         return datum["text"]
 
-    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
         return datum["label"]
 
     @staticmethod
@@ -63,10 +65,12 @@ class TREC(BasePromptsFromJsonMixin, TopicClassification):
             column_names=["label", "question"],
         )
 
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_text_for_datum(datum: DatasetDatum) -> str:
         return datum["question"]
 
-    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+    @staticmethod
+    def _get_gold_label_for_datum(datum: DatasetDatum) -> str:
         return datum["label"]
 
     @lru_cache
-- 
GitLab