From 36d6d0737b553d3232e41689f0a8b81c3f557b23 Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Thu, 1 Aug 2024 15:25:00 +0200
Subject: [PATCH] Add SAMSum dataset

---
 evoprompt/task/__init__.py      |  1 +
 evoprompt/task/summarization.py | 75 +++++++++++++++++++++++++++++++++
 2 files changed, 76 insertions(+)
 create mode 100644 evoprompt/task/summarization.py

diff --git a/evoprompt/task/__init__.py b/evoprompt/task/__init__.py
index 10378c2..75b3e79 100644
--- a/evoprompt/task/__init__.py
+++ b/evoprompt/task/__init__.py
@@ -11,6 +11,7 @@ from evoprompt.task.text_classification import TextClassification
 from evoprompt.task.sentiment_analysis import SentimentAnalysis
 from evoprompt.task.topic_classification import AGNews, TREC
 from evoprompt.task.subjectivity_classification import Subj
+from evoprompt.task.summarization import SAMSum
 from evoprompt.utils import get_all_subclasses
 
 
diff --git a/evoprompt/task/summarization.py b/evoprompt/task/summarization.py
new file mode 100644
index 0000000..853e48e
--- /dev/null
+++ b/evoprompt/task/summarization.py
@@ -0,0 +1,75 @@
+from abc import abstractmethod
+from functools import lru_cache
+import logging
+from typing import Mapping
+
+from evaluate import load as load_metric
+from llama_cpp import LlamaGrammar
+from evoprompt.models import LLMModel
+from evoprompt.task import Task
+from evoprompt.task.task import DatasetDatum
+
+
+logger = logging.getLogger(__name__)
+
+
+class Summarization(Task):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        self.metric = load_metric("evaluate-metric/rouge")
+
+    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        gold_label = self._get_gold_label_for_datum(datum)
+        response, usage = self.predict(prompt=prompt, datum=datum)
+        response = response.lower()
+
+        scores = self.metric.compute(predictions=[response], references=[gold_label])
+
+        return scores["rougeL"], usage
+
+    @lru_cache
+    def _get_grammer(self, datum: DatasetDatum, verbose: bool = False):
+        return None
+
+    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
+        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
+
+    @abstractmethod
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        pass
+
+    @abstractmethod
+    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+        pass
+
+    def _aggregate_result(self, results: list[float]) -> float:
+        return sum(results) / len(results)
+
+    @property
+    def metric_name(self):
+        return "rougeL"
+
+    @property
+    def base_prompt(self):
+        return 'Please summarize the main context.'
+
+
+class SAMSum(Summarization):
+    shorthand = "sams"
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(
+            *args,
+            validation_dataset="Samsung/samsum",
+            test_dataset="Samsung/samsum",
+            validation_split="train",
+            test_split="test",
+            **kwargs
+        )
+
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        return datum["dialogue"]
+
+    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+        return datum["summary"]
\ No newline at end of file
--
GitLab
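
For reference, the sketch below shows roughly how the per-sample scoring and aggregation added in Summarization._evaluate_sample and _aggregate_result behave, as a standalone example outside the patch. It assumes only the `evaluate` package (plus its `rouge_score` dependency); the metric id "rouge" should resolve to the same "evaluate-metric/rouge" metric the patch loads, and the prediction/reference pairs are purely illustrative.

from evaluate import load as load_metric

# Same ROUGE metric family the patch uses; returns rouge1/rouge2/rougeL/rougeLsum floats.
metric = load_metric("rouge")

# Illustrative (made-up) prediction/reference pairs, not taken from the patch or dataset.
samples = [
    ("amanda baked cookies and will bring jerry some tomorrow.",
     "Amanda baked cookies and will bring Jerry some tomorrow."),
    ("the meeting was moved to friday.",
     "The meeting is moved to Friday at 10am."),
]

scores = []
for prediction, reference in samples:
    # One ROUGE-L score per sample, mirroring _evaluate_sample.
    result = metric.compute(predictions=[prediction], references=[reference])
    scores.append(result["rougeL"])

# Dataset-level score is the plain mean, mirroring _aggregate_result.
print(sum(scores) / len(scores))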
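
As a usage note, the following sketch (again outside the patch) shows what the new SAMSum task ends up reading from the dataset. It assumes the Hugging Face `datasets` package and the "Samsung/samsum" dataset referenced above; loading that dataset may additionally require the `py7zr` package and, with newer `datasets` versions, `trust_remote_code=True`.

from datasets import load_dataset

# Dataset referenced in the patch; extra requirements (py7zr, trust_remote_code)
# depend on the installed `datasets` version.
samsum = load_dataset("Samsung/samsum", trust_remote_code=True)

# The patch draws validation data from the "train" split and evaluates on "test".
datum = samsum["test"][0]

# Mirrors _get_prompt_text_for_datum / _get_gold_label_for_datum in the patch.
model_input = "\nInput: " + '"' + datum["dialogue"] + '"'
gold_summary = datum["summary"]
print(model_input[:80], "->", gold_summary)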