diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py index 012ee0bd9354ba506a6e6dbb01519a1d404ee7f1..82bbd89505d6849bce30a3e9a1f389b17a7a9f64 100644 --- a/evoprompt/task/sentiment_analysis.py +++ b/evoprompt/task/sentiment_analysis.py @@ -42,10 +42,12 @@ class HfSST2(BasePromptsFromJsonMixin, SentimentAnalysis): def load_test_set(self, test_dataset: str, test_split: str | None): return super().load_test_set("stanfordnlp/sst2", "test") - def _get_text_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_text_for_datum(datum: DatasetDatum) -> str: return datum["sentence"] - def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_gold_label_for_datum(datum: DatasetDatum) -> str: return datum["label"] @@ -82,10 +84,12 @@ class SST2(BasePromptsFromJsonMixin, SentimentAnalysis): ) return dataset - def _get_text_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_text_for_datum(datum: DatasetDatum) -> str: return datum["sentence"] - def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_gold_label_for_datum(datum: DatasetDatum) -> str: return datum["label"] @@ -114,10 +118,12 @@ class SST5(BasePromptsFromJsonMixin, SentimentAnalysis): column_names=["label", "sentence"], ) - def _get_text_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_text_for_datum(datum: DatasetDatum) -> str: return datum["sentence"] - def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_gold_label_for_datum(datum: DatasetDatum) -> str: return datum["label"] @staticmethod @@ -146,10 +152,12 @@ class HfMovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis): "cornell-movie-review-data/rotten_tomatoes", "test" ) - def _get_text_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_text_for_datum(datum: DatasetDatum) -> str: return datum["text"] - def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_gold_label_for_datum(datum: DatasetDatum) -> str: return datum["label"] @@ -178,10 +186,12 @@ class MovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis): column_names=["label", "text"], ) - def _get_text_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_text_for_datum(datum: DatasetDatum) -> str: return datum["text"] - def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_gold_label_for_datum(datum: DatasetDatum) -> str: return datum["label"] @@ -210,8 +220,10 @@ class CustomerReviews(BasePromptsFromJsonMixin, SentimentAnalysis): column_names=["label", "text"], ) - def _get_text_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_text_for_datum(datum: DatasetDatum) -> str: return datum["text"] - def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_gold_label_for_datum(datum: DatasetDatum) -> str: return datum["label"] diff --git a/evoprompt/task/simplification.py b/evoprompt/task/simplification.py index e528e05f41220edb3e5085e22b04e9803c165ed2..1d62bc40064207750148a5e386095103ca111c5a 100644 --- a/evoprompt/task/simplification.py +++ b/evoprompt/task/simplification.py @@ -49,8 +49,10 @@ class ASSET(BasePromptsFromJsonMixin, Simplification): **kwargs ) - def _get_text_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_text_for_datum(datum: DatasetDatum) -> str: return datum["original"] - def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_gold_label_for_datum(datum: DatasetDatum) -> str: return datum["simplifications"] diff --git a/evoprompt/task/subjectivity_classification.py b/evoprompt/task/subjectivity_classification.py index aa5e57e00fbce3c7a620b3580092108714443fa1..051294d432df9abcabec557f575601e0aca2919a 100644 --- a/evoprompt/task/subjectivity_classification.py +++ b/evoprompt/task/subjectivity_classification.py @@ -34,10 +34,12 @@ class Subj(BasePromptsFromJsonMixin, TextClassification): column_names=["label", "text"], ) - def _get_text_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_text_for_datum(datum: DatasetDatum) -> str: return datum["text"] - def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_gold_label_for_datum(datum: DatasetDatum) -> str: return datum["label"] @staticmethod diff --git a/evoprompt/task/summarization.py b/evoprompt/task/summarization.py index 213131e255eefde09195519677d13805353daaa1..9d485b85fc01c5fdf4597bf5f945baf317f88932 100644 --- a/evoprompt/task/summarization.py +++ b/evoprompt/task/summarization.py @@ -50,8 +50,10 @@ class SAMSum(BasePromptsFromJsonMixin, Summarization): **kwargs ) - def _get_text_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_text_for_datum(datum: DatasetDatum) -> str: return datum["dialogue"] - def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_gold_label_for_datum(datum: DatasetDatum) -> str: return datum["summary"] diff --git a/evoprompt/task/text_classification.py b/evoprompt/task/text_classification.py index 2c40725ee927a0605f6acad89a5c6ac8eb174f0e..74e8dbcb2d8d6ab1ae8cfb262df9c9a86848f491 100644 --- a/evoprompt/task/text_classification.py +++ b/evoprompt/task/text_classification.py @@ -63,9 +63,17 @@ class TextClassification(Task): lambda _, idx: {"idx": idx}, with_indices=True ).shuffle(42) sample_ids = [] + + def check_sample_label(sample, label, gold_label_for_datum_fn): + return gold_label_for_datum_fn(sample) == label + for label in self._get_label_mapping().values(): sample_ids_for_label = dataset_with_row_indices.filter( - lambda sample: self._get_gold_label_for_datum(sample) == label + check_sample_label, + fn_kwargs={ + "label": label, + "gold_label_for_datum_fn": self._get_gold_label_for_datum, + }, )[:n_evaluation_demo]["idx"] sample_ids += sample_ids_for_label return sample_ids diff --git a/evoprompt/task/topic_classification.py b/evoprompt/task/topic_classification.py index e16aef39ca36b1aaaed81de4471ccf4c4c79dc6c..6e9107b14c58f20eeaa58fbeb277ff4a48c490b0 100644 --- a/evoprompt/task/topic_classification.py +++ b/evoprompt/task/topic_classification.py @@ -28,10 +28,12 @@ class AGNews(BasePromptsFromJsonMixin, TopicClassification): **kwargs ) - def _get_text_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_text_for_datum(datum: DatasetDatum) -> str: return datum["text"] - def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_gold_label_for_datum(datum: DatasetDatum) -> str: return datum["label"] @staticmethod @@ -63,10 +65,12 @@ class TREC(BasePromptsFromJsonMixin, TopicClassification): column_names=["label", "question"], ) - def _get_text_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_text_for_datum(datum: DatasetDatum) -> str: return datum["question"] - def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: + @staticmethod + def _get_gold_label_for_datum(datum: DatasetDatum) -> str: return datum["label"] @lru_cache