From 9f5d05bb9acb9943a0ceded643f5f82ba247e18c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de> Date: Tue, 20 Aug 2024 10:57:57 +0200 Subject: [PATCH] fix problem when setting 0 demonstration exaples --- evoprompt/task/__init__.py | 15 ++++++++------- evoprompt/task/task.py | 6 ++++-- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/evoprompt/task/__init__.py b/evoprompt/task/__init__.py index e628c1a..aa750c7 100644 --- a/evoprompt/task/__init__.py +++ b/evoprompt/task/__init__.py @@ -1,16 +1,17 @@ from evoprompt.cli import argument_parser from evoprompt.models import LLMModel -from evoprompt.task.question_answering import QuestionAnswering -from evoprompt.task.sentiment_analysis import SentimentAnalysis -from evoprompt.task.simplification import ASSET, Simplification -from evoprompt.task.subjectivity_classification import Subj -from evoprompt.task.summarization import SAMSum, Summarization # make sure to run definitions of subclasses of Task first from evoprompt.task.task import EvaluationStrategyKey, Task +from evoprompt.task.question_answering import QuestionAnswering from evoprompt.task.text_classification import TextClassification +from evoprompt.task.sentiment_analysis import SentimentAnalysis +from evoprompt.task.topic_classification import AGNews, TREC +from evoprompt.task.subjectivity_classification import Subj from evoprompt.task.text_generation import TextGeneration -from evoprompt.task.topic_classification import TREC, AGNews +from evoprompt.task.summarization import Summarization, SAMSum +from evoprompt.task.simplification import Simplification, ASSET + from evoprompt.utils import get_all_subclasses # at this point, we assume that all subclasses of Task have been defined @@ -30,7 +31,7 @@ def get_task(name: str, evaluation_model: LLMModel, **options): argument_parser.add_argument("--debug", "-d", action="store_true", default=None) argument_group = argument_parser.add_argument_group("Task arguments") argument_group.add_argument( - "--task", type=str, required=True, choices=sorted(tasks.keys()) + "--task", type=str, required=True, choices=sorted(tasks.keys()) ) argument_group.add_argument("--use-grammar", "-g", action="store_true") argument_group.add_argument( diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py index d93a66e..35c59c5 100644 --- a/evoprompt/task/task.py +++ b/evoprompt/task/task.py @@ -315,9 +315,11 @@ class Task(metaclass=ABCMeta): def load_test_set(self, test_dataset: str, test_split: str | None): return load_dataset(test_dataset, split=test_split) - def get_demonstration_samples(self, dataset: Dataset) -> list[DatasetDatum]: + def get_demonstration_samples( + self, dataset: Dataset + ) -> tuple[list[DatasetDatum], list[DatasetDatum]]: if self.n_evaluation_demo is None or self.n_evaluation_demo <= 0: - return [] + return [], dataset # get demonstration samples from validation set samples_ids = self._get_demonstration_sample_ids( -- GitLab