diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index 4348cbb2cec3fa97c4f6e3077722ae22d2ad18e0..5bb2f3cba9758e246c16397d86250508084bfced 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -1,11 +1,9 @@
 import logging
-import shelve
 from abc import ABCMeta, abstractmethod
 from collections import deque
 from dataclasses import KW_ONLY, dataclass
-from pathlib import Path
 from statistics import mean
-from typing import Iterable, Literal, Union
+from typing import Any, Iterable, Literal
 
 from datasets import Dataset, load_dataset
 from llama_cpp import LlamaGrammar
@@ -265,6 +263,7 @@ class Task(metaclass=ABCMeta):
         evaluation_strategy: EvaluationStrategyKey,
         validation_split: str | None = None,
         use_evolution_demo: bool = False,
+        n_evaluation_demo: int | None = None,
         test_split: str | None = None,
         debug: bool = False,
         **kwargs,
@@ -272,6 +271,7 @@
     ):
         self.model = model
         self.debug = debug
         self.use_grammar = use_grammar
+        self.n_evaluation_demo = n_evaluation_demo
         self.evaluation_strategy = get_evaluation_strategy(evaluation_strategy)(self)
 
         logger.info(
@@ -287,6 +287,10 @@
         self.validation_dataset = self.load_validation_set(
             validation_dataset, validation_split
         )
+
+        # get demonstration samples
+        self.demonstration_samples = self.get_demonstration_samples()
+
         if self.debug and len(self.validation_dataset) > 10:
             self.validation_dataset = self.validation_dataset.shuffle(42).select(
                 range(10)
             )
@@ -309,13 +313,38 @@
     def load_test_set(self, test_dataset: str, test_split: str | None):
         return load_dataset(test_dataset, split=test_split)
 
+    def get_demonstration_samples(self) -> Dataset | list[DatasetDatum]:
+        if self.n_evaluation_demo is None or self.n_evaluation_demo <= 0:
+            return []
+
+        # select demonstration sample ids from the validation set
+        sample_ids = self._get_demonstration_sample_ids(
+            self.validation_dataset, self.n_evaluation_demo
+        )
+        # retrieve the demonstration samples
+        demonstration_samples = self.validation_dataset.filter(
+            lambda _, idx: idx in sample_ids, with_indices=True
+        )
+        # ... and remove them from the validation set
+        self.validation_dataset = self.validation_dataset.filter(
+            lambda _, idx: idx not in sample_ids, with_indices=True
+        )
+        return demonstration_samples
+
+    @abstractmethod
+    def _get_demonstration_sample_ids(
+        self, dataset: Dataset, n_evaluation_demo: int
+    ) -> list[Any]:
+        pass
+
     def predict(self, prompt: str, datum: DatasetDatum) -> tuple[str, ModelUsage]:
         # run model for inference using grammar to constrain output
         # TODO grammar also depends on prompt and vice-versa -> what are good labels?
         response, _, usage = self.model(
             system_message=SYSTEM_MESSAGE,
             prompt=prompt,
-            prompt_appendix=self._get_prompt_text_for_datum(datum),
+            prompt_suffix="\n" + self._get_prompt_text_for_datum(datum),
+            prompt_appendix="\nResponse: ",
             # grammar can be applied to constrain the model output
             grammar=self._get_grammar(datum) if self.use_grammar else None,
         )
@@ -344,9 +373,10 @@
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         pass
 
+    # This method is needed for the demonstration examples.
+    @abstractmethod
     def _get_prompt_output_for_datum(self, datum: DatasetDatum) -> str:
-        # This method is needed for the demonstration example.
-        return self._get_gold_label_for_datum(datum)
+        pass
 
     @abstractmethod
     def _aggregate_result(self, results: list) -> float:
         pass
@@ -366,6 +396,14 @@
         evaluation_usage = ModelUsage()
         evaluation_history = []
 
+        # augment prompt with demonstration samples
+        prompt += "".join(
+            [
+                f"\n{self._get_prompt_text_for_datum(datum)}\nResponse: {self._get_prompt_output_for_datum(datum)}"
+                for datum in self.demonstration_samples
+            ]
+        )
+
         for datum in dataset_iterator:
             result, usage = self._evaluate_sample(prompt, datum)
             results.append(result)
diff --git a/evoprompt/task/text_classification.py b/evoprompt/task/text_classification.py
index c2c2cfa3429d17ef2daf13d9d5ae3b2a1665cd0e..d1e1016be0d840f64ab59f0f6912536b3d422eac 100644
--- a/evoprompt/task/text_classification.py
+++ b/evoprompt/task/text_classification.py
@@ -1,8 +1,9 @@
 from abc import abstractmethod
-from functools import lru_cache
 import logging
-from typing import Mapping
+import re
+from typing import Any, Mapping
 
+from datasets import Dataset
 from llama_cpp import LlamaGrammar
 from evoprompt.task import Task
 from evoprompt.task.task import DatasetDatum
@@ -16,24 +17,42 @@ class TextClassification(Task):
         gold_label = self._get_gold_label_for_datum(datum)
         class_mapping = self._get_label_mapping()
         response, usage = self.predict(prompt=prompt, datum=datum)
-        response = response.lower()
         if self.use_grammar:
             # model output is from label space
             answer_label = class_mapping[response]
         else:
-            answer_label = None
-            for label in class_mapping.keys():
-                if label in response:
-                    answer_label = class_mapping[label]
-                    break
+            matches = re.findall(
+                # regex that matches any class label after "Response: "
+                rf"Response: ({'|'.join(map(re.escape, class_mapping.keys()))})",
+                response,
+                flags=re.IGNORECASE,
+            )
+            if matches:
+                answer_label = class_mapping[matches[-1].lower()]
             else:
+                # TODO we could fall back to e.g. checking whether a class label appears anywhere in the response
                 logger.warning(f"Invalid answer: {response}")
                 return "failed", usage
 
         classification_result = "incorrect" if answer_label != gold_label else "correct"
         return classification_result, usage
 
-    # @lru_cache
+    def _get_demonstration_sample_ids(
+        self, dataset: Dataset, n_evaluation_demo: int
+    ) -> list[Any]:
+        # we need to return row indices, hence we first add them as a new column so they survive filtering
+        dataset_with_row_indices = dataset.map(
+            lambda _, idx: {"idx": idx}, with_indices=True
+        ).shuffle(42)
+        sample_ids = []
+        for label in self._get_label_mapping().values():
+            sample_ids_for_label = dataset_with_row_indices.filter(
+                lambda sample: self._get_gold_label_for_datum(sample) == label
+            )[:n_evaluation_demo]["idx"]
+            sample_ids += sample_ids_for_label
+        return sample_ids
+
+    # NOTE cannot be cached since grammar is not picklable
     def _get_grammar(self, datum: DatasetDatum, verbose: bool = False):
         return LlamaGrammar.from_string(
             "root ::= ({})".format(
@@ -43,6 +62,7 @@
             )
         )
 
     def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
+        # TODO do we need quotes?
return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"' @abstractmethod @@ -57,6 +77,10 @@ class TextClassification(Task): def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: pass + def _get_prompt_output_for_datum(self, datum: DatasetDatum) -> str: + id_to_label = {v: k for k, v in self._get_label_mapping().items()} + return id_to_label[self._get_gold_label_for_datum(datum)] + def _aggregate_result(self, results: list[str]) -> float: num_correct_results = sum(1 for result in results if result == "correct") accuracy = num_correct_results / len(results)