From e7e571d91b62bfe1143ad44f6325344df35d3327 Mon Sep 17 00:00:00 2001 From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de> Date: Thu, 1 Aug 2024 12:39:05 +0200 Subject: [PATCH] Fix SST2 dataset --- evoprompt/task/sentiment_analysis.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py index 3ea64aa..8e9c410 100644 --- a/evoprompt/task/sentiment_analysis.py +++ b/evoprompt/task/sentiment_analysis.py @@ -64,16 +64,26 @@ class SST2(SentimentAnalysis): ): return load_dataset( "csv", - data_files={"validation": "evoprompt/data/sst-2/train.csv"}, + data_files={"validation": "evoprompt/data/sst-2/train.tsv"}, split="validation", + sep="\t", ) def load_test_set(self, test_dataset: str, test_split: str | None): - return load_dataset( - "csv", - data_files={"test": "evoprompt/data/sst-2/dev.csv"}, + # this dataset is not in correct format therefore we process it after reading as a simple text file + dataset = load_dataset( + "text", + data_files={"test": "evoprompt/data/sst-2/test.tsv"}, split="test", ) + dataset = dataset.map( + lambda sample: { + "label": int(sample["text"][0]), + "sentence": sample["text"][2:], + }, + remove_columns="text", + ) + return dataset def _get_text_for_datum(self, datum: DatasetDatum) -> str: return datum["sentence"] -- GitLab