Skip to content
Snippets Groups Projects
Commit 2b2251c4 authored by Max Kimmich's avatar Max Kimmich
Browse files

Improve code readability

parent 4030ec32
No related branches found
No related tags found
1 merge request!3Add demonstration data for evaluation and change answer extraction for text classification tasks
...@@ -289,7 +289,9 @@ class Task(metaclass=ABCMeta): ...@@ -289,7 +289,9 @@ class Task(metaclass=ABCMeta):
) )
# get demonstration samples # get demonstration samples
self.demonstration_samples = self.get_demonstration_samples() self.demonstration_samples, self.validation_dataset = (
self.get_demonstration_samples(self.validation_dataset)
)
if self.debug and len(self.validation_dataset) > 10: if self.debug and len(self.validation_dataset) > 10:
self.validation_dataset = self.validation_dataset.shuffle(42).select( self.validation_dataset = self.validation_dataset.shuffle(42).select(
...@@ -313,23 +315,23 @@ class Task(metaclass=ABCMeta): ...@@ -313,23 +315,23 @@ class Task(metaclass=ABCMeta):
def load_test_set(self, test_dataset: str, test_split: str | None): def load_test_set(self, test_dataset: str, test_split: str | None):
return load_dataset(test_dataset, split=test_split) return load_dataset(test_dataset, split=test_split)
def get_demonstration_samples(self) -> list[DatasetDatum]: def get_demonstration_samples(self, dataset: Dataset) -> list[DatasetDatum]:
if self.n_evaluation_demo is None or self.n_evaluation_demo <= 0: if self.n_evaluation_demo is None or self.n_evaluation_demo <= 0:
return [] return []
# get demonstration samples from validation set # get demonstration samples from validation set
samples_ids = self._get_demonstration_sample_ids( samples_ids = self._get_demonstration_sample_ids(
self.validation_dataset, self.n_evaluation_demo dataset, self.n_evaluation_demo
) )
# retrieve demonstration samples from validation set # retrieve demonstration samples from validation set
demonstration_samples = self.validation_dataset.filter( demonstration_samples = dataset.filter(
lambda _, idx: idx in samples_ids, with_indices=True lambda _, idx: idx in samples_ids, with_indices=True
) )
# remove demonstration samples from validation set # remove demonstration samples from validation set
self.validation_dataset = self.validation_dataset.filter( remaining_dataset = self.dataset.filter(
lambda _, idx: idx not in samples_ids, with_indices=True lambda _, idx: idx not in samples_ids, with_indices=True
) )
return demonstration_samples return demonstration_samples, remaining_dataset
@abstractmethod @abstractmethod
def _get_demonstration_sample_ids( def _get_demonstration_sample_ids(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment