Skip to content
Snippets Groups Projects
Commit 9f5d05bb authored by Grießhaber Daniel's avatar Grießhaber Daniel :squid:
Browse files

fix problem when setting 0 demonstration exaples

parent 9462d643
No related branches found
No related tags found
1 merge request!5Llm as a judge
from evoprompt.cli import argument_parser
from evoprompt.models import LLMModel
from evoprompt.task.question_answering import QuestionAnswering
from evoprompt.task.sentiment_analysis import SentimentAnalysis
from evoprompt.task.simplification import ASSET, Simplification
from evoprompt.task.subjectivity_classification import Subj
from evoprompt.task.summarization import SAMSum, Summarization
# make sure to run definitions of subclasses of Task first
from evoprompt.task.task import EvaluationStrategyKey, Task
from evoprompt.task.question_answering import QuestionAnswering
from evoprompt.task.text_classification import TextClassification
from evoprompt.task.sentiment_analysis import SentimentAnalysis
from evoprompt.task.topic_classification import AGNews, TREC
from evoprompt.task.subjectivity_classification import Subj
from evoprompt.task.text_generation import TextGeneration
from evoprompt.task.topic_classification import TREC, AGNews
from evoprompt.task.summarization import Summarization, SAMSum
from evoprompt.task.simplification import Simplification, ASSET
from evoprompt.utils import get_all_subclasses
# at this point, we assume that all subclasses of Task have been defined
......@@ -30,7 +31,7 @@ def get_task(name: str, evaluation_model: LLMModel, **options):
argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
argument_group = argument_parser.add_argument_group("Task arguments")
argument_group.add_argument(
"--task", type=str, required=True, choices=sorted(tasks.keys())
"--task", type=str, required=True, choices=sorted(tasks.keys())
)
argument_group.add_argument("--use-grammar", "-g", action="store_true")
argument_group.add_argument(
......
......@@ -315,9 +315,11 @@ class Task(metaclass=ABCMeta):
def load_test_set(self, test_dataset: str, test_split: str | None):
return load_dataset(test_dataset, split=test_split)
def get_demonstration_samples(self, dataset: Dataset) -> list[DatasetDatum]:
def get_demonstration_samples(
self, dataset: Dataset
) -> tuple[list[DatasetDatum], list[DatasetDatum]]:
if self.n_evaluation_demo is None or self.n_evaluation_demo <= 0:
return []
return [], dataset
# get demonstration samples from validation set
samples_ids = self._get_demonstration_sample_ids(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment