From 36d6d0737b553d3232e41689f0a8b81c3f557b23 Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Thu, 1 Aug 2024 15:25:00 +0200
Subject: [PATCH] Add SAMSum dataset

---
 evoprompt/task/__init__.py      |  1 +
 evoprompt/task/summarization.py | 75 +++++++++++++++++++++++++++++++++
 2 files changed, 76 insertions(+)
 create mode 100644 evoprompt/task/summarization.py

diff --git a/evoprompt/task/__init__.py b/evoprompt/task/__init__.py
index 10378c2..75b3e79 100644
--- a/evoprompt/task/__init__.py
+++ b/evoprompt/task/__init__.py
@@ -11,6 +11,7 @@ from evoprompt.task.text_classification import TextClassification
 from evoprompt.task.sentiment_analysis import SentimentAnalysis
 from evoprompt.task.topic_classification import AGNews, TREC
 from evoprompt.task.subjectivity_classification import Subj
+from evoprompt.task.summarization import SAMSum
 
 from evoprompt.utils import get_all_subclasses
 
diff --git a/evoprompt/task/summarization.py b/evoprompt/task/summarization.py
new file mode 100644
index 0000000..853e48e
--- /dev/null
+++ b/evoprompt/task/summarization.py
@@ -0,0 +1,110 @@
+from abc import abstractmethod
+from functools import lru_cache
+import logging
+from typing import Mapping
+
+from evaluate import load as load_metric
+from llama_cpp import LlamaGrammar
+from evoprompt.models import LLMModel
+from evoprompt.task import Task
+from evoprompt.task.task import DatasetDatum
+
+
+logger = logging.getLogger(__name__)
+
+
+class Summarization(Task):
+    """Base class for summarization tasks scored with ROUGE.
+
+    Subclasses supply the input text and the reference summary of a datum
+    via `_get_text_for_datum` and `_get_gold_label_for_datum`.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        """Initialize the base task and load the ROUGE metric."""
+        super().__init__(*args, **kwargs)
+
+        self.metric = load_metric("evaluate-metric/rouge")
+
+    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        """Score the prediction for one datum against its reference summary.
+
+        Returns:
+            A tuple of the ROUGE score named by `metric_name` and the usage
+            value returned by `self.predict`.
+        """
+        gold_label = self._get_gold_label_for_datum(datum)
+        response, usage = self.predict(prompt=prompt, datum=datum)
+        response = response.lower()
+
+        # Request only the ROUGE variant that is reported, instead of
+        # computing every default variant for each sample.
+        scores = self.metric.compute(
+            predictions=[response],
+            references=[gold_label],
+            rouge_types=[self.metric_name],
+        )
+
+        return scores[self.metric_name], usage
+
+    def _get_grammer(self, datum: DatasetDatum, verbose: bool = False):
+        """Return `None`; this task does not supply a grammar."""
+        # Deliberately not wrapped in `lru_cache`: the result is constant, and
+        # caching would require `datum` to be hashable while keeping cached
+        # `self`/`datum` references alive.
+        return None
+
+    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
+        """Format the datum's text as the quoted `Input:` part of a prompt."""
+        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
+
+    @abstractmethod
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        """Return the text of `datum` that is to be summarized."""
+
+    @abstractmethod
+    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+        """Return the reference summary of `datum`."""
+
+    def _aggregate_result(self, results: list[float]) -> float:
+        """Return the arithmetic mean of the per-sample scores.
+
+        Raises:
+            ZeroDivisionError: If `results` is empty.
+        """
+        return sum(results) / len(results)
+
+    @property
+    def metric_name(self) -> str:
+        """Name of the ROUGE variant used as this task's score."""
+        return "rougeL"
+
+    @property
+    def base_prompt(self) -> str:
+        """Base prompt text for summarization tasks."""
+        return 'Please summarize the main context.'
+
+
+class SAMSum(Summarization):
+    """Summarization task using the `Samsung/samsum` dataset."""
+
+    shorthand = "sams"
+
+    def __init__(self, *args, **kwargs) -> None:
+        """Pass the SAMSum dataset names and splits on to the base class."""
+        super().__init__(
+            *args,
+            validation_dataset="Samsung/samsum",
+            test_dataset="Samsung/samsum",
+            validation_split="train",
+            test_split="test",
+            **kwargs
+        )
+
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        """Return the `dialogue` field of `datum`."""
+        return datum["dialogue"]
+
+    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+        """Return the `summary` field of `datum`."""
+        return datum["summary"]
-- 
GitLab