From 9f5d05bb9acb9943a0ceded643f5f82ba247e18c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Tue, 20 Aug 2024 10:57:57 +0200
Subject: [PATCH] fix problem when setting 0 demonstration exaples

---
 evoprompt/task/__init__.py | 15 ++++++++-------
 evoprompt/task/task.py     |  6 ++++--
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/evoprompt/task/__init__.py b/evoprompt/task/__init__.py
index e628c1a..aa750c7 100644
--- a/evoprompt/task/__init__.py
+++ b/evoprompt/task/__init__.py
@@ -1,16 +1,17 @@
 from evoprompt.cli import argument_parser
 from evoprompt.models import LLMModel
-from evoprompt.task.question_answering import QuestionAnswering
-from evoprompt.task.sentiment_analysis import SentimentAnalysis
-from evoprompt.task.simplification import ASSET, Simplification
-from evoprompt.task.subjectivity_classification import Subj
-from evoprompt.task.summarization import SAMSum, Summarization
 
 # make sure to run definitions of subclasses of Task first
 from evoprompt.task.task import EvaluationStrategyKey, Task
+from evoprompt.task.question_answering import QuestionAnswering
 from evoprompt.task.text_classification import TextClassification
+from evoprompt.task.sentiment_analysis import SentimentAnalysis
+from evoprompt.task.topic_classification import AGNews, TREC
+from evoprompt.task.subjectivity_classification import Subj
 from evoprompt.task.text_generation import TextGeneration
-from evoprompt.task.topic_classification import TREC, AGNews
+from evoprompt.task.summarization import Summarization, SAMSum
+from evoprompt.task.simplification import Simplification, ASSET
+
 from evoprompt.utils import get_all_subclasses
 
 # at this point, we assume that all subclasses of Task have been defined
@@ -30,7 +31,7 @@ def get_task(name: str, evaluation_model: LLMModel, **options):
 argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
 argument_group = argument_parser.add_argument_group("Task arguments")
 argument_group.add_argument(
-    "--task", type=str, required=True, choices=sorted(tasks.keys())
+    "--task",  type=str, required=True, choices=sorted(tasks.keys())
 )
 argument_group.add_argument("--use-grammar", "-g", action="store_true")
 argument_group.add_argument(
diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index d93a66e..35c59c5 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -315,9 +315,11 @@ class Task(metaclass=ABCMeta):
     def load_test_set(self, test_dataset: str, test_split: str | None):
         return load_dataset(test_dataset, split=test_split)
 
-    def get_demonstration_samples(self, dataset: Dataset) -> list[DatasetDatum]:
+    def get_demonstration_samples(
+        self, dataset: Dataset
+    ) -> tuple[list[DatasetDatum], list[DatasetDatum]]:
         if self.n_evaluation_demo is None or self.n_evaluation_demo <= 0:
-            return []
+            return [], dataset
 
         # get demonstration samples from validation set
         samples_ids = self._get_demonstration_sample_ids(
-- 
GitLab