From 4030ec32088275d12a04aa5e8b394296fed0ee9b Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Mon, 19 Aug 2024 17:33:58 +0200
Subject: [PATCH] Refactor text generation tasks and implement demonstration
 samples for missing tasks

---
 evoprompt/task/__init__.py                    |  3 +-
 evoprompt/task/question_answering.py          | 38 +++++++++++++---
 evoprompt/task/sentiment_analysis.py          | 23 +++++-----
 evoprompt/task/simplification.py              | 37 +++-------------
 evoprompt/task/subjectivity_classification.py |  9 ++--
 evoprompt/task/summarization.py               | 40 +++--------------
 evoprompt/task/task.py                        |  9 ++--
 evoprompt/task/text_classification.py         | 31 +++++++++----
 evoprompt/task/text_generation.py             | 43 +++++++++++++++++++
 evoprompt/task/topic_classification.py        | 10 ++---
 10 files changed, 136 insertions(+), 107 deletions(-)
 create mode 100644 evoprompt/task/text_generation.py

diff --git a/evoprompt/task/__init__.py b/evoprompt/task/__init__.py
index 0bfbdf2..c83a698 100644
--- a/evoprompt/task/__init__.py
+++ b/evoprompt/task/__init__.py
@@ -11,6 +11,7 @@ from evoprompt.task.text_classification import TextClassification
 from evoprompt.task.sentiment_analysis import SentimentAnalysis
 from evoprompt.task.topic_classification import AGNews, TREC
 from evoprompt.task.subjectivity_classification import Subj
+from evoprompt.task.text_generation import TextGeneration
 from evoprompt.task.summarization import Summarization, SAMSum
 from evoprompt.task.simplification import Simplification, ASSET
 
@@ -33,7 +34,7 @@ def get_task(name: str, evaluation_model: LLMModel, **options):
 argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
 argument_group = argument_parser.add_argument_group("Task arguments")
 argument_group.add_argument(
-    "--task", "-t", type=str, required=True, choices=tasks.keys()
+    "--task", "-t", type=str, required=True, choices=sorted(tasks.keys())
 )
 argument_group.add_argument("--use-grammar", "-g", action="store_true")
 argument_group.add_argument(
diff --git a/evoprompt/task/question_answering.py b/evoprompt/task/question_answering.py
index f85d441..fbd2d60 100644
--- a/evoprompt/task/question_answering.py
+++ b/evoprompt/task/question_answering.py
@@ -1,7 +1,8 @@
 import logging
 import re
 from abc import abstractmethod
-from functools import lru_cache
+from functools import cache, lru_cache
+from typing import Iterable
 
 from datasets import Dataset
 from evaluate import load as load_metric
@@ -9,6 +10,7 @@ from llama_cpp import LlamaGrammar
 
 from evoprompt.opt_types import ModelUsage
 from evoprompt.task.task import DatasetDatum, Task
+from evoprompt.utils import get_rng
 
 logger = logging.getLogger(__name__)
 
@@ -53,7 +55,7 @@ class QuestionAnswering(Task):
 
         self.metric = load_metric("squad")
 
-    @lru_cache
+    @cache
     def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar:
         # context-sensitive grammar
         context = self._get_context_from_datum(datum)
@@ -74,6 +76,12 @@ class QuestionAnswering(Task):
             "\nContext: " + '"' + context + '"' + "\nQuestion: " + '"' + question + '"'
         )
 
+    def _get_demonstration_sample_ids(
+        self, dataset: Dataset, n_evaluation_demo: int
+    ) -> Iterable[int]:
+        # select demonstration samples uniformly at random
+        return get_rng().choice(len(dataset), n_evaluation_demo, replace=False)
+
     @abstractmethod
     def _get_id_from_datum(self, datum: DatasetDatum):
         pass
@@ -89,8 +97,23 @@ class QuestionAnswering(Task):
     def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
         _id = self._get_id_from_datum(datum)
         gold_answers = self._get_gold_label_for_datum(datum)
-        answer, usage = self.predict(prompt, datum)
-        # TODO check if answer is lower-cased in metric computation
+        response, usage = self.predict(prompt, datum)
+        response = response.lower()
+        # with a grammar, the response is already an answer span taken from the context
+        answer = response
+        if not self.use_grammar:
+            # without a grammar, the answer needs to be extracted from the response
+            matches = re.findall(
+                # match the answer after an optional "Response: " prefix
+                r"(?:Response:\s?)?(.+)",
+                response.splitlines()[-1],
+                flags=re.IGNORECASE,
+            )
+            # if no answer can be extracted, fall back to the whole response
+            if matches:
+                answer = matches[-1]
+            else:
+                answer = response
 
         result = self.metric.compute(
             predictions=[{"prediction_text": answer, "id": _id}],
@@ -140,7 +163,9 @@ class QuestionAnswering(Task):
     @property
     def base_prompts(self):
         # TODO find good base prompts
-        return ["""In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""]
+        return [
+            """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
+        ]
 
 
 class SQuAD(QuestionAnswering):
@@ -165,3 +190,6 @@
 
     def _get_gold_label_for_datum(self, datum: DatasetDatum):
         return datum["answers"]
+
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
+        return self._get_gold_label_for_datum(datum)["text"][0]
diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py
index d3697c2..6886ed1 100644
--- a/evoprompt/task/sentiment_analysis.py
+++ b/evoprompt/task/sentiment_analysis.py
@@ -1,16 +1,11 @@
-import json
 import logging
-from abc import abstractmethod
-from argparse import Namespace
-from functools import lru_cache
-from pathlib import Path
+from functools import cache
 from typing import Mapping
 
 from datasets import load_dataset
-from llama_cpp import LlamaGrammar
 
 from evoprompt.helpers.prompts import BasePromptsFromJsonMixin
-from evoprompt.task import Task, TextClassification
+from evoprompt.task import TextClassification
 from evoprompt.task.task import DatasetDatum
 
 logger = logging.getLogger(__name__)
@@ -22,13 +17,14 @@
 
 
 class SentimentAnalysis(TextClassification):
-    @lru_cache
-    def _get_label_mapping(self) -> Mapping:
+    @staticmethod
+    @cache
+    def _get_label_mapping() -> Mapping:
         return {"negative": 0, "positive": 1}
 
 
 class HfSST2(BasePromptsFromJsonMixin, SentimentAnalysis):
-    shorthand = "hf-sst2"
+    shorthand = "sst2-hf"
     base_prompts_file = "evoprompt/initial_prompts/sst-2/prompts.json"
 
     def load_validation_set(
@@ -111,14 +107,15 @@ class SST5(BasePromptsFromJsonMixin, SentimentAnalysis):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         return datum["label"]
 
-    @lru_cache
-    def _get_label_mapping(self) -> Mapping:
+    @staticmethod
+    @cache
+    def _get_label_mapping() -> Mapping:
         classes = ["terrible", "bad", "okay", "good", "great"]
         return dict(zip(classes, range(len(classes))))
 
 
 class HfMovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis):
-    shorthand = "hf-mr"
+    shorthand = "mr-hf"
     base_prompts_file = "evoprompt/initial_prompts/mr/prompts.json"
 
     def load_validation_set(
diff --git a/evoprompt/task/simplification.py b/evoprompt/task/simplification.py
index 8295db7..4d58aeb 100644
--- a/evoprompt/task/simplification.py
+++ b/evoprompt/task/simplification.py
@@ -1,52 +1,27 @@
-import json
 import logging
-from abc import abstractmethod
-from functools import lru_cache
-from pathlib import Path
-from typing import Mapping
 
 from evaluate import load as load_metric
-from llama_cpp import LlamaGrammar
 
 from evoprompt.helpers.prompts import BasePromptsFromJsonMixin
-from evoprompt.models import LLMModel
-from evoprompt.task import Task
+from evoprompt.task import TextGeneration
 from evoprompt.task.task import DatasetDatum
 
 logger = logging.getLogger(__name__)
 
 
-class Simplification(Task):
+class Simplification(TextGeneration):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.metric = load_metric("evaluate-metric/sari")
 
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+    def compute_metric(self, datum: DatasetDatum, prediction: str) -> float:
         gold_label = self._get_gold_label_for_datum(datum)
 
-        response, usage = self.predict(prompt=prompt, datum=datum)
-        response = response.lower()
-
-        scores = self.metric.compute(
+        return self.metric.compute(
             sources=[self._get_text_for_datum(datum)],
-            predictions=[response],
+            predictions=[prediction],
             references=[gold_label],
-        )
-        return scores["sari"], usage
-
-    @lru_cache
-    def _get_grammar(self, datum: DatasetDatum, verbose: bool = False):
-        return None
-
-    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
-        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
-
-    @abstractmethod
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
-        pass
-
-    def _aggregate_result(self, results: list[str]) -> float:
-        return sum(results) / len(results)
+        )["sari"]
 
     @property
     def metric_name(self):
diff --git a/evoprompt/task/subjectivity_classification.py b/evoprompt/task/subjectivity_classification.py
index 63988c6..5ded8da 100644
--- a/evoprompt/task/subjectivity_classification.py
+++ b/evoprompt/task/subjectivity_classification.py
@@ -1,6 +1,4 @@
-import json
-from functools import lru_cache
-from pathlib import Path
+from functools import cache
 from typing import Mapping
 
 from datasets import load_dataset
@@ -39,7 +37,8 @@ class Subj(BasePromptsFromJsonMixin, TextClassification):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         return datum["label"]
 
-    @lru_cache
-    def _get_label_mapping(self) -> Mapping:
+    @staticmethod
+    @cache
+    def _get_label_mapping() -> Mapping:
         classes = ["subjective", "objective"]
         return dict(zip(classes, range(len(classes))))
diff --git a/evoprompt/task/summarization.py b/evoprompt/task/summarization.py
index d199937..3d45aef 100644
--- a/evoprompt/task/summarization.py
+++ b/evoprompt/task/summarization.py
@@ -1,53 +1,25 @@
-import json
 import logging
-from abc import abstractmethod
-from functools import lru_cache
-from pathlib import Path
-from typing import Mapping
 
 from evaluate import load as load_metric
-from llama_cpp import LlamaGrammar
 
 from evoprompt.helpers.prompts import BasePromptsFromJsonMixin
-from evoprompt.models import LLMModel
-from evoprompt.task import Task
+from evoprompt.task import TextGeneration
 from evoprompt.task.task import DatasetDatum
 
 logger = logging.getLogger(__name__)
 
 
-class Summarization(Task):
+class Summarization(TextGeneration):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.metric = load_metric("evaluate-metric/rouge")
 
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+    def compute_metric(self, datum: DatasetDatum, prediction: str) -> float:
         gold_label = self._get_gold_label_for_datum(datum)
 
-        response, usage = self.predict(prompt=prompt, datum=datum)
-        response = response.lower()
-
-        scores = self.metric.compute(predictions=[response], references=[gold_label])
-
-        return scores["rougeL"], usage
-
-    @lru_cache
-    def _get_grammar(self, datum: DatasetDatum, verbose: bool = False):
-        return None
-
-    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
-        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
-
-    @abstractmethod
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
-        pass
-
-    @abstractmethod
-    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
-        pass
-
-    def _aggregate_result(self, results: list[str]) -> float:
-        return sum(results) / len(results)
+        return self.metric.compute(predictions=[prediction], references=[gold_label])[
+            "rougeL"
+        ]
 
     @property
     def metric_name(self):
diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index 5bb2f3c..fc91fc5 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -3,7 +3,7 @@ from abc import ABCMeta, abstractmethod
 from collections import deque
 from dataclasses import KW_ONLY, dataclass
 from statistics import mean
-from typing import Any, Iterable, Literal
+from typing import Iterable, Literal
 
 from datasets import Dataset, load_dataset
 from llama_cpp import LlamaGrammar
@@ -334,7 +334,7 @@ class Task(metaclass=ABCMeta):
     @abstractmethod
     def _get_demonstration_sample_ids(
         self, dataset: Dataset, n_evaluation_demo: int
-    ) -> list[Any]:
+    ) -> Iterable[int]:
         pass
 
     def predict(self, prompt: str, datum: DatasetDatum) -> tuple[str, ModelUsage]:
@@ -342,6 +342,7 @@ class Task(metaclass=ABCMeta):
         # TODO grammar also depends on prompt and vice-versa -> what are good labels?
         response, _, usage = self.model(
             system_message=SYSTEM_MESSAGE,
+            # TODO Allow to modify prompt construction in subclasses
             prompt=prompt,
             prompt_suffix="\n" + self._get_prompt_text_for_datum(datum),
             prompt_appendix="\nResponse: ",
@@ -375,7 +376,7 @@ class Task(metaclass=ABCMeta):
 
     @abstractmethod
     # This method is needed for the demonstration examples.
-    def _get_prompt_output_for_datum(self, datum: DatasetDatum) -> str:
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
         pass
 
     @abstractmethod
@@ -399,7 +400,7 @@ class Task(metaclass=ABCMeta):
         # augment prompt with demonstration samples
         prompt += "".join(
             [
-                f"\n{self._get_prompt_text_for_datum(datum)}\nResponse: {self._get_prompt_output_for_datum(datum)}"
+                f"\n{self._get_prompt_text_for_datum(datum)}\nResponse: {self._get_gold_label_generation_for_datum(datum)}"
                 for datum in self.demonstration_samples
             ]
         )
diff --git a/evoprompt/task/text_classification.py b/evoprompt/task/text_classification.py
index d1e1016..ea7f382 100644
--- a/evoprompt/task/text_classification.py
+++ b/evoprompt/task/text_classification.py
@@ -1,14 +1,15 @@
-from abc import abstractmethod
 import logging
 import re
+from abc import abstractmethod
+from functools import lru_cache
 from typing import Any, Mapping
 
 from datasets import Dataset
 from llama_cpp import LlamaGrammar
+
 from evoprompt.task import Task
 from evoprompt.task.task import DatasetDatum
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -17,12 +18,13 @@ class TextClassification(Task):
         gold_label = self._get_gold_label_for_datum(datum)
         class_mapping = self._get_label_mapping()
         response, usage = self.predict(prompt=prompt, datum=datum)
+        response = response.lower()
         if self.use_grammar:
             # model output is from label space
             answer_label = class_mapping[response]
         else:
             matches = re.findall(
-                # regex that matches "negative" or "positive" after "Response: "
+                # regex that matches class labels after "Response: "
                 rf"Response: ({'|'.join(class_mapping.keys())})",
                 response,
                 flags=re.IGNORECASE,
@@ -30,9 +32,18 @@ class TextClassification(Task):
             if matches:
                 answer_label = class_mapping[matches[-1]]
             else:
-                # TODO in this case we could try other stuff, like checking if a class label is somewhere in the response?
-                logger.warning(f"Invalid answer: {response}")
-                return "failed", usage
+                # look for a label anywhere in the response; if none is found, return failed
+                matches = re.findall(
+                    # regex that matches class labels anywhere in the response
+                    rf"({'|'.join(class_mapping.keys())})",
+                    response,
+                    flags=re.IGNORECASE,
+                )
+                if matches:
+                    answer_label = class_mapping[matches[-1]]
+                else:
+                    logger.warning(f"Invalid answer: {response}")
+                    return "failed", usage
 
         classification_result = "incorrect" if answer_label != gold_label else "correct"
         return classification_result, usage
@@ -77,8 +88,12 @@ class TextClassification(Task):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         pass
 
-    def _get_prompt_output_for_datum(self, datum: DatasetDatum) -> str:
-        id_to_label = {v: k for k, v in self._get_label_mapping().items()}
+    @lru_cache
+    def _get_inverse_label_mapping(self):
+        return {v: k for k, v in self._get_label_mapping().items()}
+
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
+        id_to_label = self._get_inverse_label_mapping()
         return id_to_label[self._get_gold_label_for_datum(datum)]
 
     def _aggregate_result(self, results: list[str]) -> float:
diff --git a/evoprompt/task/text_generation.py b/evoprompt/task/text_generation.py
new file mode 100644
index 0000000..7e20cd6
--- /dev/null
+++ b/evoprompt/task/text_generation.py
@@ -0,0 +1,43 @@
+import logging
+from abc import abstractmethod
+from typing import Iterable
+
+from datasets import Dataset
+from llama_cpp import LlamaGrammar
+
+from evoprompt.task import Task
+from evoprompt.task.task import DatasetDatum
+from evoprompt.utils import get_rng
+
+logger = logging.getLogger(__name__)
+
+
+class TextGeneration(Task):
+    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        response, usage = self.predict(prompt=prompt, datum=datum)
+        response = response.lower()
+        return self.compute_metric(datum, response), usage
+
+    def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar:
+        # there is no grammar for open text generation
+        return None
+
+    def _get_demonstration_sample_ids(
+        self, dataset: Dataset, n_evaluation_demo: int
+    ) -> Iterable[int]:
+        # select demonstration samples uniformly at random
+        return get_rng().choice(len(dataset), n_evaluation_demo, replace=False)
+
+    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
+        # TODO do we need quotes?
+        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
+
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
+        return self._get_gold_label_for_datum(datum)
+
+    @abstractmethod
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        pass
+
+    def _aggregate_result(self, results: list[float]) -> float:
+        return sum(results) / len(results)
diff --git a/evoprompt/task/topic_classification.py b/evoprompt/task/topic_classification.py
index aecec3e..4103a43 100644
--- a/evoprompt/task/topic_classification.py
+++ b/evoprompt/task/topic_classification.py
@@ -1,12 +1,9 @@
-import json
-from functools import lru_cache
-from pathlib import Path
+from functools import cache, lru_cache
 from typing import Mapping
 
 from datasets import load_dataset
 
 from evoprompt.helpers.prompts import BasePromptsFromJsonMixin
-from evoprompt.models import LLMModel
 from evoprompt.task import TextClassification
 from evoprompt.task.task import DatasetDatum
 
@@ -31,8 +28,9 @@ class AGNews(BasePromptsFromJsonMixin, TextClassification):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         return datum["label"]
 
-    @lru_cache
-    def _get_label_mapping(self) -> Mapping:
+    @staticmethod
+    @cache
+    def _get_label_mapping() -> Mapping:
         classes = ["world", "sports", "business", "tech"]
         return dict(zip(classes, range(len(classes))))
 
-- 
GitLab