From e7e571d91b62bfe1143ad44f6325344df35d3327 Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Thu, 1 Aug 2024 12:39:05 +0200
Subject: [PATCH] Fix SST2 dataset

---
 evoprompt/task/sentiment_analysis.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py
index 3ea64aa..8e9c410 100644
--- a/evoprompt/task/sentiment_analysis.py
+++ b/evoprompt/task/sentiment_analysis.py
@@ -64,16 +64,26 @@ class SST2(SentimentAnalysis):
     ):
         return load_dataset(
             "csv",
-            data_files={"validation": "evoprompt/data/sst-2/train.csv"},
+            data_files={"validation": "evoprompt/data/sst-2/train.tsv"},
             split="validation",
+            sep="\t",
         )
 
     def load_test_set(self, test_dataset: str, test_split: str | None):
-        return load_dataset(
-            "csv",
-            data_files={"test": "evoprompt/data/sst-2/dev.csv"},
+        # this dataset is not in correct format therefore we process it after reading as a simple text file
+        dataset = load_dataset(
+            "text",
+            data_files={"test": "evoprompt/data/sst-2/test.tsv"},
             split="test",
         )
+        dataset = dataset.map(
+            lambda sample: {
+                "label": int(sample["text"][0]),
+                "sentence": sample["text"][2:],
+            },
+            remove_columns="text",
+        )
+        return dataset
 
     def _get_text_for_datum(self, datum: DatasetDatum) -> str:
         return datum["sentence"]
-- 
GitLab