From 2b2251c494c0bd04d93fbffb1803df27df370862 Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Mon, 19 Aug 2024 18:07:23 +0200
Subject: [PATCH] Improve code readability

---
 evoprompt/task/task.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index fc91fc5..8588373 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -289,7 +289,9 @@ class Task(metaclass=ABCMeta):
         )
 
         # get demonstration samples
-        self.demonstration_samples = self.get_demonstration_samples()
+        self.demonstration_samples, self.validation_dataset = (
+            self.get_demonstration_samples(self.validation_dataset)
+        )
 
         if self.debug and len(self.validation_dataset) > 10:
             self.validation_dataset = self.validation_dataset.shuffle(42).select(
@@ -313,23 +315,32 @@ class Task(metaclass=ABCMeta):
     def load_test_set(self, test_dataset: str, test_split: str | None):
         return load_dataset(test_dataset, split=test_split)
 
-    def get_demonstration_samples(self) -> list[DatasetDatum]:
+    def get_demonstration_samples(
+        self, dataset: Dataset
+    ) -> tuple[list[DatasetDatum], Dataset]:
+        """Split demonstration samples off the given dataset.
+
+        Returns a pair ``(demonstration_samples, remaining_dataset)``. If no
+        demonstrations are requested (``n_evaluation_demo`` is ``None`` or
+        non-positive), the pair is ``([], dataset)`` so that callers can
+        always unpack two values.
+        """
         if self.n_evaluation_demo is None or self.n_evaluation_demo <= 0:
-            return []
+            return [], dataset
 
         # get demonstration samples from validation set
         samples_ids = self._get_demonstration_sample_ids(
-            self.validation_dataset, self.n_evaluation_demo
+            dataset, self.n_evaluation_demo
         )
         # retrieve demonstration samples from validation set
-        demonstration_samples = self.validation_dataset.filter(
+        demonstration_samples = dataset.filter(
             lambda _, idx: idx in samples_ids, with_indices=True
         )
         # remove demonstration samples from validation set
-        self.validation_dataset = self.validation_dataset.filter(
+        remaining_dataset = dataset.filter(
             lambda _, idx: idx not in samples_ids, with_indices=True
         )
-        return demonstration_samples
+        return demonstration_samples, remaining_dataset
 
     @abstractmethod
     def _get_demonstration_sample_ids(
-- 
GitLab