diff --git a/evoprompt/task/__init__.py b/evoprompt/task/__init__.py
index 0bfbdf273c1cb6a0a5fe9232efc949b637a4f73e..c83a698b57a52a9f05c5ba7d2ba0dfd8dd43233a 100644
--- a/evoprompt/task/__init__.py
+++ b/evoprompt/task/__init__.py
@@ -11,6 +11,7 @@
 from evoprompt.task.text_classification import TextClassification
 from evoprompt.task.sentiment_analysis import SentimentAnalysis
 from evoprompt.task.topic_classification import AGNews, TREC
 from evoprompt.task.subjectivity_classification import Subj
+from evoprompt.task.text_generation import TextGeneration
 from evoprompt.task.summarization import Summarization, SAMSum
 from evoprompt.task.simplification import Simplification, ASSET
@@ -33,7 +34,7 @@ def get_task(name: str, evaluation_model: LLMModel, **options):
 argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
 argument_group = argument_parser.add_argument_group("Task arguments")
 argument_group.add_argument(
-    "--task", "-t", type=str, required=True, choices=tasks.keys()
+    "--task", "-t", type=str, required=True, choices=sorted(tasks.keys())
 )
 argument_group.add_argument("--use-grammar", "-g", action="store_true")
 argument_group.add_argument(
diff --git a/evoprompt/task/question_answering.py b/evoprompt/task/question_answering.py
index f85d441957cdaf6281e3541e5bfb3bc431364d90..fbd2d60867e89ee95bb532a7ee2a8d3283de8097 100644
--- a/evoprompt/task/question_answering.py
+++ b/evoprompt/task/question_answering.py
@@ -1,7 +1,8 @@
 import logging
 import re
 from abc import abstractmethod
-from functools import lru_cache
+from functools import cache, lru_cache
+from typing import Iterable
 
 from datasets import Dataset
 from evaluate import load as load_metric
@@ -9,6 +10,7 @@
 from llama_cpp import LlamaGrammar
 
 from evoprompt.opt_types import ModelUsage
 from evoprompt.task.task import DatasetDatum, Task
+from evoprompt.utils import get_rng
 
 logger = logging.getLogger(__name__)
@@ -53,7 +55,7 @@ class QuestionAnswering(Task):
 
         self.metric = load_metric("squad")
 
-    @lru_cache
+    @cache
    def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar:
         # context-sensitive grammar
         context = self._get_context_from_datum(datum)
@@ -74,6 +76,12 @@ class QuestionAnswering(Task):
             "\nContext: " + '"' + context + '"' + "\nQuestion: " + '"' + question + '"'
         )
 
+    def _get_demonstration_sample_ids(
+        self, dataset: Dataset, n_evaluation_demo: int
+    ) -> Iterable[int]:
+        # select demonstration samples uniformly at random
+        return get_rng().choice(len(dataset), n_evaluation_demo, replace=False)
+
     @abstractmethod
     def _get_id_from_datum(self, datum: DatasetDatum):
         pass
@@ -89,8 +97,23 @@
     def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
         _id = self._get_id_from_datum(datum)
         gold_answers = self._get_gold_label_for_datum(datum)
-        answer, usage = self.predict(prompt, datum)
-        # TODO check if answer is lower-cased in metric computation
+        response, usage = self.predict(prompt, datum)
+        response = response.lower()
+
+        if not self.use_grammar:
+            # if we do not use a grammar, we need to extract the answer from the response
+            # otherwise the answer is from the context as enforced by the grammar
+            matches = re.findall(
+                # regex that matches class labels after "Response: "
+                rf"(?:Response:\s?)?(.+)",
+                response.splitlines()[-1],
+                flags=re.IGNORECASE,
+            )
+            # look for an answer in the response, if not found, use whole response
+            if matches:
+                answer = matches[-1]
+            else:
+                answer = response
 
         result = self.metric.compute(
             predictions=[{"prediction_text": answer, "id": _id}],
@@ -140,7 +163,9 @@ class QuestionAnswering(Task):
     @property
     def base_prompts(self):
         # TODO find good base prompts
-        return ["""In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""]
+        return [
+            """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
+        ]
 
 
 class SQuAD(QuestionAnswering):
@@ -165,3 +190,6 @@ class SQuAD(QuestionAnswering):
 
     def _get_gold_label_for_datum(self, datum: DatasetDatum):
         return datum["answers"]
+
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
+        return self._get_gold_label_for_datum(datum)["text"][0]
diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py
index d3697c20ac2508531afcb39a69ec9da90ca0c71a..6886ed18ce89032b4f19b56f99155abab7f222b0 100644
--- a/evoprompt/task/sentiment_analysis.py
+++ b/evoprompt/task/sentiment_analysis.py
@@ -1,16 +1,11 @@
-import json
 import logging
-from abc import abstractmethod
-from argparse import Namespace
-from functools import lru_cache
-from pathlib import Path
+from functools import cache
 from typing import Mapping
 
 from datasets import load_dataset
-from llama_cpp import LlamaGrammar
 
 from evoprompt.helpers.prompts import BasePromptsFromJsonMixin
-from evoprompt.task import Task, TextClassification
+from evoprompt.task import TextClassification
 from evoprompt.task.task import DatasetDatum
 
 logger = logging.getLogger(__name__)
@@ -22,13 +17,14 @@
 
 
 class SentimentAnalysis(TextClassification):
-    @lru_cache
-    def _get_label_mapping(self) -> Mapping:
+    @staticmethod
+    @cache
+    def _get_label_mapping() -> Mapping:
         return {"negative": 0, "positive": 1}
 
 
 class HfSST2(BasePromptsFromJsonMixin, SentimentAnalysis):
-    shorthand = "hf-sst2"
+    shorthand = "sst2-hf"
     base_prompts_file = "evoprompt/initial_prompts/sst-2/prompts.json"
 
     def load_validation_set(
@@ -111,14 +107,15 @@ class SST5(BasePromptsFromJsonMixin, SentimentAnalysis):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         return datum["label"]
 
-    @lru_cache
-    def _get_label_mapping(self) -> Mapping:
+    @staticmethod
+    @cache
+    def _get_label_mapping() -> Mapping:
         classes = ["terrible", "bad", "okay", "good", "great"]
         return dict(zip(classes, range(len(classes))))
 
 
 class HfMovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis):
-    shorthand = "hf-mr"
+    shorthand = "mr-hf"
     base_prompts_file = "evoprompt/initial_prompts/mr/prompts.json"
 
     def load_validation_set(
diff --git a/evoprompt/task/simplification.py b/evoprompt/task/simplification.py
index 8295db713986895859f400051f268cb023a36331..4d58aebf61d4d9d167a6d623839b674b1c0a95e3 100644
--- a/evoprompt/task/simplification.py
+++ b/evoprompt/task/simplification.py
@@ -1,52 +1,27 @@
-import json
 import logging
-from abc import abstractmethod
-from functools import lru_cache
-from pathlib import Path
-from typing import Mapping
 
 from evaluate import load as load_metric
-from llama_cpp import LlamaGrammar
 
 from evoprompt.helpers.prompts import BasePromptsFromJsonMixin
-from evoprompt.models import LLMModel
-from evoprompt.task import Task
+from evoprompt.task import TextGeneration
 from evoprompt.task.task import DatasetDatum
 
 logger = logging.getLogger(__name__)
 
 
-class Simplification(Task):
+class Simplification(TextGeneration):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
 
         self.metric = load_metric("evaluate-metric/sari")
 
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+    def compute_metric(self, datum: DatasetDatum, prediction: str) -> float:
         gold_label = self._get_gold_label_for_datum(datum)
-        response, usage = self.predict(prompt=prompt, datum=datum)
-        response = response.lower()
-
-        scores = self.metric.compute(
+        return self.metric.compute(
             sources=[self._get_text_for_datum(datum)],
-            predictions=[response],
+            predictions=[prediction],
             references=[gold_label],
-        )
-        return scores["sari"], usage
-
-    @lru_cache
-    def _get_grammar(self, datum: DatasetDatum, verbose: bool = False):
-        return None
-
-    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
-        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
-
-    @abstractmethod
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
-        pass
-
-    def _aggregate_result(self, results: list[str]) -> float:
-        return sum(results) / len(results)
+        )["sari"]
 
     @property
     def metric_name(self):
diff --git a/evoprompt/task/subjectivity_classification.py b/evoprompt/task/subjectivity_classification.py
index 63988c6aafde7edc293149cedfb9e882e8e2c4c4..5ded8dada78f412afa15f7362a7813b9f32fb487 100644
--- a/evoprompt/task/subjectivity_classification.py
+++ b/evoprompt/task/subjectivity_classification.py
@@ -1,6 +1,4 @@
-import json
-from functools import lru_cache
-from pathlib import Path
+from functools import cache
 from typing import Mapping
 
 from datasets import load_dataset
@@ -39,7 +37,8 @@ class Subj(BasePromptsFromJsonMixin, TextClassification):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         return datum["label"]
 
-    @lru_cache
-    def _get_label_mapping(self) -> Mapping:
+    @staticmethod
+    @cache
+    def _get_label_mapping() -> Mapping:
         classes = ["subjective", "objective"]
         return dict(zip(classes, range(len(classes))))
diff --git a/evoprompt/task/summarization.py b/evoprompt/task/summarization.py
index d199937f76f4799d77d6a8a32095b99cc674d6dd..3d45aefdca87d03d30025276cc5a391b596dee74 100644
--- a/evoprompt/task/summarization.py
+++ b/evoprompt/task/summarization.py
@@ -1,53 +1,25 @@
-import json
 import logging
-from abc import abstractmethod
-from functools import lru_cache
-from pathlib import Path
-from typing import Mapping
 
 from evaluate import load as load_metric
-from llama_cpp import LlamaGrammar
 
 from evoprompt.helpers.prompts import BasePromptsFromJsonMixin
-from evoprompt.models import LLMModel
-from evoprompt.task import Task
+from evoprompt.task import TextGeneration
 from evoprompt.task.task import DatasetDatum
 
 logger = logging.getLogger(__name__)
 
 
-class Summarization(Task):
+class Summarization(TextGeneration):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
 
         self.metric = load_metric("evaluate-metric/rouge")
 
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+    def compute_metric(self, datum: DatasetDatum, prediction: str) -> float:
         gold_label = self._get_gold_label_for_datum(datum)
-        response, usage = self.predict(prompt=prompt, datum=datum)
-        response = response.lower()
-
-        scores = self.metric.compute(predictions=[response], references=[gold_label])
-
-        return scores["rougeL"], usage
-
-    @lru_cache
-    def _get_grammar(self, datum: DatasetDatum, verbose: bool = False):
-        return None
-
-    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
-        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
-
-    @abstractmethod
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
-        pass
-
-    @abstractmethod
-    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
-        pass
-
-    def _aggregate_result(self, results: list[str]) -> float:
-        return sum(results) / len(results)
+        return self.metric.compute(predictions=[prediction], references=[gold_label])[
+            "rougeL"
+        ]
 
     @property
     def metric_name(self):
diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index 8b27328b2139fc50716f132f8b26acbabeed8139..cfc79a208161243ce89ecbf9ad97d95d11091d6a 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -1,11 +1,9 @@
 import logging
-import shelve
 from abc import ABCMeta, abstractmethod
 from collections import deque
 from dataclasses import KW_ONLY, dataclass
-from pathlib import Path
 from statistics import mean
-from typing import Iterable, Literal, Union
+from typing import Iterable, Literal
 
 from datasets import Dataset, load_dataset
 from llama_cpp import LlamaGrammar
@@ -265,6 +263,7 @@ class Task(metaclass=ABCMeta):
         evaluation_strategy: EvaluationStrategyKey,
         validation_split: str | None = None,
         use_evolution_demo: bool = False,
+        n_evaluation_demo: int | None = None,
         test_split: str | None = None,
         debug: bool = False,
         **kwargs,
@@ -272,6 +271,7 @@
         self.model = model
         self.debug = debug
         self.use_grammar = use_grammar
+        self.n_evaluation_demo = n_evaluation_demo
         self.evaluation_strategy = get_evaluation_strategy(evaluation_strategy)(self)
 
         logger.info(
@@ -287,6 +287,12 @@
         self.validation_dataset = self.load_validation_set(
             validation_dataset, validation_split
         )
+
+        # get demonstration samples
+        self.demonstration_samples, self.validation_dataset = (
+            self.get_demonstration_samples(self.validation_dataset)
+        )
+
         if self.debug and len(self.validation_dataset) > 10:
             self.validation_dataset = self.validation_dataset.shuffle(42).select(
                 range(10)
             )
@@ -309,13 +315,39 @@
     def load_test_set(self, test_dataset: str, test_split: str | None):
         return load_dataset(test_dataset, split=test_split)
 
+    def get_demonstration_samples(self, dataset: Dataset) -> list[DatasetDatum]:
+        if self.n_evaluation_demo is None or self.n_evaluation_demo <= 0:
+            return []
+
+        # get demonstration samples from validation set
+        samples_ids = self._get_demonstration_sample_ids(
+            dataset, self.n_evaluation_demo
+        )
+        # retrieve demonstration samples from validation set
+        demonstration_samples = dataset.filter(
+            lambda _, idx: idx in samples_ids, with_indices=True
+        )
+        # remove demonstration samples from validation set
+        remaining_dataset = self.dataset.filter(
+            lambda _, idx: idx not in samples_ids, with_indices=True
+        )
+        return demonstration_samples, remaining_dataset
+
+    @abstractmethod
+    def _get_demonstration_sample_ids(
+        self, dataset: Dataset, n_evaluation_demo: int
+    ) -> Iterable[int]:
+        pass
+
     def predict(self, prompt: str, datum: DatasetDatum) -> tuple[str, ModelUsage]:
         # run model for inference using grammar to constrain output
         # TODO grammar also depends on prompt and vice-versa -> what are good labels?
         response, _, usage = self.model.create_completion(
             system_message=SYSTEM_MESSAGE,
+            # TODO Allow to modify prompt construction in subclasses
             prompt=prompt,
-            prompt_appendix=self._get_prompt_text_for_datum(datum),
+            prompt_suffix="\n" + self._get_prompt_text_for_datum(datum),
+            prompt_appendix="\nResponse: ",
             # grammar can be applied to constrain the model output
             grammar=self._get_grammar(datum) if self.use_grammar else None,
             # we use cached completions to speed up the process although we loose the non-deterministic behavior of LMs, but we're ok with a single result
@@ -346,9 +378,10 @@ class Task(metaclass=ABCMeta):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         pass
 
-    def _get_prompt_output_for_datum(self, datum: DatasetDatum) -> str:
-        # This method is needed for the demonstration example.
-        return self._get_gold_label_for_datum(datum)
+    @abstractmethod
+    # This method is needed for the demonstration examples.
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
+        pass
 
     @abstractmethod
     def _aggregate_result(self, results: list) -> float:
@@ -368,6 +401,14 @@
         evaluation_usage = ModelUsage()
         evaluation_history = []
 
+        # augment prompt with demonstration samples
+        prompt += "".join(
+            [
+                f"\n{self._get_prompt_text_for_datum(datum)}\nResponse: {self._get_gold_label_generation_for_datum(datum)}"
+                for datum in self.demonstration_samples
+            ]
+        )
+
         for datum in dataset_iterator:
             result, usage = self._evaluate_sample(prompt, datum)
             results.append(result)
diff --git a/evoprompt/task/text_classification.py b/evoprompt/task/text_classification.py
index c2c2cfa3429d17ef2daf13d9d5ae3b2a1665cd0e..ea7f3821dd62b1a1190a63a5e22ee7ed70388a82 100644
--- a/evoprompt/task/text_classification.py
+++ b/evoprompt/task/text_classification.py
@@ -1,13 +1,15 @@
+import logging
+import re
 from abc import abstractmethod
 from functools import lru_cache
-import logging
-from typing import Mapping
+from typing import Any, Mapping
 
+from datasets import Dataset
 from llama_cpp import LlamaGrammar
+
 from evoprompt.task import Task
 from evoprompt.task.task import DatasetDatum
-
 
 logger = logging.getLogger(__name__)
@@ -21,19 +23,47 @@ class TextClassification(Task):
             # model output is from label space
             answer_label = class_mapping[response]
         else:
-            answer_label = None
-            for label in class_mapping.keys():
-                if label in response:
-                    answer_label = class_mapping[label]
-                    break
+            matches = re.findall(
+                # regex that matches class labels after "Response: "
+                rf"Response: ({'|'.join(class_mapping.keys())})",
+                response,
+                flags=re.IGNORECASE,
+            )
+            if matches:
+                answer_label = class_mapping[matches[-1]]
             else:
-                logger.warning(f"Invalid answer: {response}")
-                return "failed", usage
+                # look for a label in the response, if not found, return failed
+                matches = re.findall(
+                    # regex that matches class labels anywhere in the response
+                    rf"({'|'.join(class_mapping.keys())})",
+                    response,
+                    flags=re.IGNORECASE,
+                )
+                if matches:
+                    answer_label = class_mapping[matches[-1]]
+                else:
+                    logger.warning(f"Invalid answer: {response}")
+                    return "failed", usage
 
         classification_result = "incorrect" if answer_label != gold_label else "correct"
         return classification_result, usage
 
-    # @lru_cache
+    def _get_demonstration_sample_ids(
+        self, dataset: Dataset, n_evaluation_demo: int
+    ) -> list[Any]:
+        # we need to return row indices hence we add them first as a new column to keep track of them
+        dataset_with_row_indices = dataset.map(
+            lambda _, idx: {"idx": idx}, with_indices=True
+        ).shuffle(42)
+        sample_ids = []
+        for label in self._get_label_mapping().values():
+            sample_ids_for_label = dataset_with_row_indices.filter(
+                lambda sample: self._get_gold_label_for_datum(sample) == label
+            )[:n_evaluation_demo]["idx"]
+            sample_ids += sample_ids_for_label
+        return sample_ids
+
+    # NOTE cannot be cached since grammar is not picklable
     def _get_grammar(self, datum: DatasetDatum, verbose: bool = False):
         return LlamaGrammar.from_string(
             "root ::= ({})".format(
@@ -43,6 +73,7 @@
         )
 
     def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
+        # TODO do we need quotes?
         return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
 
     @abstractmethod
@@ -57,6 +88,14 @@
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         pass
 
+    @lru_cache
+    def _get_inverse_label_mapping(self):
+        return {v: k for k, v in self._get_label_mapping().items()}
+
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
+        id_to_label = self._get_inverse_label_mapping()
+        return id_to_label[self._get_gold_label_for_datum(datum)]
+
     def _aggregate_result(self, results: list[str]) -> float:
         num_correct_results = sum(1 for result in results if result == "correct")
         accuracy = num_correct_results / len(results)
diff --git a/evoprompt/task/text_generation.py b/evoprompt/task/text_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e20cd67f5f1ade917b5180cb952d1e379720048
--- /dev/null
+++ b/evoprompt/task/text_generation.py
@@ -0,0 +1,43 @@
+import logging
+from abc import abstractmethod
+from typing import Iterable
+
+from datasets import Dataset
+from llama_cpp import LlamaGrammar
+
+from evoprompt.task import Task
+from evoprompt.task.task import DatasetDatum
+from evoprompt.utils import get_rng
+
+logger = logging.getLogger(__name__)
+
+
+class TextGeneration(Task):
+    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        response, usage = self.predict(prompt=prompt, datum=datum)
+        response = response.lower()
+        return self.compute_metric(datum, response), usage
+
+    def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar:
+        # there is no grammar for open text generation
+        return None
+
+    def _get_demonstration_sample_ids(
+        self, dataset: Dataset, n_evaluation_demo: int
+    ) -> Iterable[int]:
+        # select demonstration samples uniformly at random
+        return get_rng().choice(len(dataset), n_evaluation_demo, replace=False)
+
+    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
+        # TODO do we need quotes?
+        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
+
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
+        return self._get_gold_label_for_datum(datum)
+
+    @abstractmethod
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        pass
+
+    def _aggregate_result(self, results: list[str]) -> float:
+        return sum(results) / len(results)
diff --git a/evoprompt/task/topic_classification.py b/evoprompt/task/topic_classification.py
index aecec3ee9b544b17b7d990d8feecc010a8aef747..4103a43242c0f9adbf83adb75bf73cac16e19d34 100644
--- a/evoprompt/task/topic_classification.py
+++ b/evoprompt/task/topic_classification.py
@@ -1,12 +1,9 @@
-import json
-from functools import lru_cache
-from pathlib import Path
+from functools import cache, lru_cache
 from typing import Mapping
 
 from datasets import load_dataset
 
 from evoprompt.helpers.prompts import BasePromptsFromJsonMixin
-from evoprompt.models import LLMModel
 from evoprompt.task import TextClassification
 from evoprompt.task.task import DatasetDatum
@@ -31,8 +28,9 @@ class AGNews(BasePromptsFromJsonMixin, TextClassification):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         return datum["label"]
 
-    @lru_cache
-    def _get_label_mapping(self) -> Mapping:
+    @staticmethod
+    @cache
+    def _get_label_mapping() -> Mapping:
         classes = ["world", "sports", "business", "tech"]
         return dict(zip(classes, range(len(classes))))