Skip to content
Snippets Groups Projects
Commit 36d6d073 authored by Max Kimmich
Browse files

Add SAMSum dataset

parent 458e36ee
No related branches found
No related tags found
No related merge requests found
......@@ -11,6 +11,7 @@ from evoprompt.task.text_classification import TextClassification
from evoprompt.task.sentiment_analysis import SentimentAnalysis
from evoprompt.task.topic_classification import AGNews, TREC
from evoprompt.task.subjectivity_classification import Subj
from evoprompt.task.summarization import SAMSum
from evoprompt.utils import get_all_subclasses
......
from abc import abstractmethod
from functools import lru_cache
import logging
from typing import Mapping
from evaluate import load as load_metric
from llama_cpp import LlamaGrammar
from evoprompt.models import LLMModel
from evoprompt.task import Task
from evoprompt.task.task import DatasetDatum
logger = logging.getLogger(__name__)
class Summarization(Task):
    """Base task for summarization datasets, scored with ROUGE-L.

    Subclasses supply the dataset configuration and implement
    ``_get_text_for_datum`` and ``_get_gold_label_for_datum``.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Initialise the parent task and load the ROUGE metric.

        All positional and keyword arguments are forwarded unchanged to the
        parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.metric = load_metric("evaluate-metric/rouge")

    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
        """Score the model's response for one datum.

        Args:
            prompt: The prompt passed on to ``self.predict``.
            datum: The dataset entry to evaluate.

        Returns:
            A ``(score, usage)`` pair: the ``"rougeL"`` entry of the metric
            result and the usage value returned by ``self.predict``.
        """
        gold_label = self._get_gold_label_for_datum(datum)
        response, usage = self.predict(prompt=prompt, datum=datum)
        # Only the prediction is lowercased; the reference is passed on as-is.
        response = response.lower()
        scores = self.metric.compute(predictions=[response], references=[gold_label])
        return scores["rougeL"], usage

    # Deliberately not wrapped in ``lru_cache``: the result is a constant, and
    # caching a method would keep every instance alive in the cache and would
    # fail with ``TypeError`` for unhashable ``datum`` values.
    # NOTE(review): the name is spelled ``_get_grammer``; confirm whether the
    # parent class expects ``_get_grammar`` instead.
    def _get_grammer(self, datum: DatasetDatum, verbose: bool = False) -> None:
        """Return ``None``; this class supplies no grammar.

        Both arguments are accepted but ignored.
        """
        return None

    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
        """Return the datum's text wrapped as ``\\nInput: "<text>"``."""
        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'

    @abstractmethod
    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
        """Return the input text of ``datum`` (implemented by subclasses)."""
        pass

    @abstractmethod
    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
        """Return the reference text of ``datum`` (implemented by subclasses)."""
        pass

    def _aggregate_result(self, results: list[float]) -> float:
        """Return the arithmetic mean of ``results``.

        An empty list yields ``0.0`` instead of raising ``ZeroDivisionError``.
        """
        if not results:
            return 0.0
        return sum(results) / len(results)

    @property
    def metric_name(self) -> str:
        """Name of the reported metric; matches the key read in ``_evaluate_sample``."""
        return "rougeL"

    @property
    def base_prompt(self) -> str:
        """Default instruction prompt for this task."""
        return 'Please summarize the main context.'
class SAMSum(Summarization):
    """Summarization task configured for the ``Samsung/samsum`` dataset."""

    shorthand = "sams"

    def __init__(self, *args, **kwargs) -> None:
        """Forward all arguments to the parent with the SAMSum dataset settings.

        Validation and test data both use ``Samsung/samsum``; validation reads
        the ``train`` split and testing reads the ``test`` split.
        """
        dataset_id = "Samsung/samsum"
        super().__init__(
            *args,
            validation_dataset=dataset_id,
            test_dataset=dataset_id,
            validation_split="train",
            test_split="test",
            **kwargs,
        )

    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
        """Return the ``"dialogue"`` field of ``datum``."""
        dialogue = datum["dialogue"]
        return dialogue

    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
        """Return the ``"summary"`` field of ``datum``."""
        reference_summary = datum["summary"]
        return reference_summary
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment