Skip to content
Snippets Groups Projects
Commit 2b2251c4 authored by Max Kimmich
Browse files

Improve code readability

parent 4030ec32
Branches
No related tags found
1 merge request: !3 "Add demonstration data for evaluation and change answer extraction for text classification tasks"
......@@ -289,7 +289,9 @@ class Task(metaclass=ABCMeta):
)
# get demonstration samples
self.demonstration_samples = self.get_demonstration_samples()
self.demonstration_samples, self.validation_dataset = (
self.get_demonstration_samples(self.validation_dataset)
)
if self.debug and len(self.validation_dataset) > 10:
self.validation_dataset = self.validation_dataset.shuffle(42).select(
......@@ -313,23 +315,23 @@ class Task(metaclass=ABCMeta):
def load_test_set(self, test_dataset: str, test_split: str | None):
    """Load the test set.

    Args:
        test_dataset: Dataset identifier forwarded to ``load_dataset``.
        test_split: Split forwarded to ``load_dataset``; may be ``None``.

    Returns:
        Whatever ``load_dataset`` returns for the given dataset and split.
    """
    test_set = load_dataset(test_dataset, split=test_split)
    return test_set
def get_demonstration_samples(
    self, dataset: Dataset
) -> tuple[list[DatasetDatum], Dataset]:
    """Split *dataset* into demonstration samples and the remaining samples.

    The given dataset is not modified; both parts are obtained by filtering
    it on the row indices chosen by ``_get_demonstration_sample_ids``.

    Args:
        dataset: Dataset to draw the demonstration samples from.

    Returns:
        A ``(demonstration_samples, remaining_dataset)`` pair. When
        ``self.n_evaluation_demo`` is ``None`` or not positive, nothing is
        drawn: the first element is an empty list and ``dataset`` is
        returned as the second element.
    """
    if self.n_evaluation_demo is None or self.n_evaluation_demo <= 0:
        # always return a pair, since callers unpack the result into
        # (demonstration samples, remaining dataset)
        return [], dataset

    # build the id collection once; a set keeps the per-row membership
    # tests in the filter predicates cheap
    # (assumes the ids are hashable row indices -- TODO confirm)
    samples_ids = set(
        self._get_demonstration_sample_ids(dataset, self.n_evaluation_demo)
    )

    # rows selected as demonstrations
    demonstration_samples = dataset.filter(
        lambda _, idx: idx in samples_ids, with_indices=True
    )
    # all other rows; filter the dataset that was passed in, not an
    # attribute on self, so both parts come from the same source
    remaining_dataset = dataset.filter(
        lambda _, idx: idx not in samples_ids, with_indices=True
    )
    return demonstration_samples, remaining_dataset
@abstractmethod
def _get_demonstration_sample_ids(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment