diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index fc91fc5ca573b3ac20b90a2585ef8195f849f074..8588373da5a40568cabf481639f66687025b1bc0 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -289,7 +289,9 @@ class Task(metaclass=ABCMeta):
         )
 
         # get demonstration samples
-        self.demonstration_samples = self.get_demonstration_samples()
+        self.demonstration_samples, self.validation_dataset = (
+            self.get_demonstration_samples(self.validation_dataset)
+        )
 
         if self.debug and len(self.validation_dataset) > 10:
             self.validation_dataset = self.validation_dataset.shuffle(42).select(
@@ -313,23 +315,31 @@ class Task(metaclass=ABCMeta):
     def load_test_set(self, test_dataset: str, test_split: str | None):
         return load_dataset(test_dataset, split=test_split)
 
-    def get_demonstration_samples(self) -> list[DatasetDatum]:
+    def get_demonstration_samples(
+        self, dataset: Dataset
+    ) -> tuple[list[DatasetDatum], Dataset]:
+        """Split demonstration samples off from ``dataset``.
+
+        Returns a ``(demonstration_samples, remaining_dataset)`` pair. If no
+        demonstrations are requested, ``dataset`` is returned unchanged as
+        the second element so callers can always unpack two values.
+        """
         if self.n_evaluation_demo is None or self.n_evaluation_demo <= 0:
-            return []
+            return [], dataset
 
         # get demonstration samples from validation set
         samples_ids = self._get_demonstration_sample_ids(
-            self.validation_dataset, self.n_evaluation_demo
+            dataset, self.n_evaluation_demo
         )
         # retrieve demonstration samples from validation set
-        demonstration_samples = self.validation_dataset.filter(
+        demonstration_samples = dataset.filter(
             lambda _, idx: idx in samples_ids, with_indices=True
         )
         # remove demonstration samples from validation set
-        self.validation_dataset = self.validation_dataset.filter(
+        remaining_dataset = dataset.filter(
             lambda _, idx: idx not in samples_ids, with_indices=True
         )
-        return demonstration_samples
+        return demonstration_samples, remaining_dataset
 
     @abstractmethod
     def _get_demonstration_sample_ids(