Skip to content
Snippets Groups Projects
Commit e7e571d9 authored by Max Kimmich
Browse files

Fix SST2 dataset

parent b607b76a
No related branches found
No related tags found
No related merge requests found
......@@ -64,16 +64,26 @@ class SST2(SentimentAnalysis):
):
return load_dataset(
"csv",
data_files={"validation": "evoprompt/data/sst-2/train.csv"},
data_files={"validation": "evoprompt/data/sst-2/train.tsv"},
split="validation",
sep="\t",
)
def load_test_set(self, test_dataset: str, test_split: str | None):
    """Load the SST-2 test split from the local ``test.tsv`` file.

    The file is read line by line with the ``"text"`` loader, and every raw
    line is then split into a ``label`` / ``sentence`` pair.

    Args:
        test_dataset: Not used by this implementation; the file path is fixed.
        test_split: Not used by this implementation; the split is always
            ``"test"``.

    Returns:
        The mapped dataset with the columns ``label`` (int) and ``sentence``;
        the raw ``text`` column is removed.
    """

    def split_label_and_sentence(sample):
        # NOTE(review): assumes every line starts with a single-digit label
        # followed by exactly one separator character, with the sentence
        # beginning at index 2 -- confirm against the data file.
        line = sample["text"]
        return {
            "label": int(line[0]),
            "sentence": line[2:],
        }

    # The file is not in the expected format, so it is read as a plain text
    # file and each line is processed afterwards.
    raw_lines = load_dataset(
        "text",
        data_files={"test": "evoprompt/data/sst-2/test.tsv"},
        split="test",
    )
    return raw_lines.map(split_label_and_sentence, remove_columns="text")
def _get_text_for_datum(self, datum: DatasetDatum) -> str:
return datum["sentence"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment