From 4030ec32088275d12a04aa5e8b394296fed0ee9b Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Mon, 19 Aug 2024 17:33:58 +0200
Subject: [PATCH] Refactor text generation tasks and implement demonstration
 samples for missing tasks

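Text generation tasks now share a common TextGeneration base class: a
concrete task only supplies the input text, the gold reference, and a
metric, while prompt construction, demonstration sampling, and result
aggregation are inherited. Rough sketch of a task built on top of this
change (class name, dataset fields, and metric choice are illustrative;
base prompts and dataset loading are omitted):

    from evaluate import load as load_metric

    from evoprompt.task import TextGeneration
    from evoprompt.task.task import DatasetDatum

    class MyGenerationTask(TextGeneration):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            self.metric = load_metric("evaluate-metric/rouge")

        def _get_text_for_datum(self, datum: DatasetDatum) -> str:
            # hypothetical dataset field holding the source text
            return datum["text"]

        def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
            # hypothetical dataset field holding the reference output
            return datum["reference"]

        def compute_metric(self, datum: DatasetDatum, prediction: str) -> float:
            # score the (lower-cased) model response against the reference
            return self.metric.compute(
                predictions=[prediction],
                references=[self._get_gold_label_for_datum(datum)],
            )["rougeL"]
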
---
 evoprompt/task/__init__.py                    |  3 +-
 evoprompt/task/question_answering.py          | 38 +++++++++++++---
 evoprompt/task/sentiment_analysis.py          | 23 +++++-----
 evoprompt/task/simplification.py              | 37 +++-------------
 evoprompt/task/subjectivity_classification.py |  9 ++--
 evoprompt/task/summarization.py               | 40 +++--------------
 evoprompt/task/task.py                        |  9 ++--
 evoprompt/task/text_classification.py         | 31 +++++++++----
 evoprompt/task/text_generation.py             | 43 +++++++++++++++++++
 evoprompt/task/topic_classification.py        | 10 ++---
 10 files changed, 136 insertions(+), 107 deletions(-)
 create mode 100644 evoprompt/task/text_generation.py

diff --git a/evoprompt/task/__init__.py b/evoprompt/task/__init__.py
index 0bfbdf2..c83a698 100644
--- a/evoprompt/task/__init__.py
+++ b/evoprompt/task/__init__.py
@@ -11,6 +11,7 @@ from evoprompt.task.text_classification import TextClassification
 from evoprompt.task.sentiment_analysis import SentimentAnalysis
 from evoprompt.task.topic_classification import AGNews, TREC
 from evoprompt.task.subjectivity_classification import Subj
+from evoprompt.task.text_generation import TextGeneration
 from evoprompt.task.summarization import Summarization, SAMSum
 from evoprompt.task.simplification import Simplification, ASSET
 
@@ -33,7 +34,7 @@ def get_task(name: str, evaluation_model: LLMModel, **options):
 argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
 argument_group = argument_parser.add_argument_group("Task arguments")
 argument_group.add_argument(
-    "--task", "-t", type=str, required=True, choices=tasks.keys()
+    "--task", "-t", type=str, required=True, choices=sorted(tasks.keys())
 )
 argument_group.add_argument("--use-grammar", "-g", action="store_true")
 argument_group.add_argument(
diff --git a/evoprompt/task/question_answering.py b/evoprompt/task/question_answering.py
index f85d441..fbd2d60 100644
--- a/evoprompt/task/question_answering.py
+++ b/evoprompt/task/question_answering.py
@@ -1,7 +1,8 @@
 import logging
 import re
 from abc import abstractmethod
-from functools import lru_cache
+from functools import cache, lru_cache
+from typing import Iterable
 
 from datasets import Dataset
 from evaluate import load as load_metric
@@ -9,6 +10,7 @@ from llama_cpp import LlamaGrammar
 
 from evoprompt.opt_types import ModelUsage
 from evoprompt.task.task import DatasetDatum, Task
+from evoprompt.utils import get_rng
 
 logger = logging.getLogger(__name__)
 
@@ -53,7 +55,7 @@ class QuestionAnswering(Task):
 
         self.metric = load_metric("squad")
 
-    @lru_cache
+    @cache
     def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar:
         # context-sensitive grammar
         context = self._get_context_from_datum(datum)
@@ -74,6 +76,12 @@ class QuestionAnswering(Task):
             "\nContext: " + '"' + context + '"' + "\nQuestion: " + '"' + question + '"'
         )
 
+    def _get_demonstration_sample_ids(
+        self, dataset: Dataset, n_evaluation_demo: int
+    ) -> Iterable[int]:
+        # select demonstration samples uniformly at random
+        return get_rng().choice(len(dataset), n_evaluation_demo, replace=False)
+
     @abstractmethod
     def _get_id_from_datum(self, datum: DatasetDatum):
         pass
@@ -89,8 +97,23 @@ class QuestionAnswering(Task):
     def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
         _id = self._get_id_from_datum(datum)
         gold_answers = self._get_gold_label_for_datum(datum)
-        answer, usage = self.predict(prompt, datum)
-        # TODO check if answer is lower-cased in metric computation
+        response, usage = self.predict(prompt, datum)
+        response = response.lower()
+        # if a grammar is used, the response is constrained to a span from the
+        # context and can serve as the answer directly
+        answer = response
+
+        if not self.use_grammar:
+            # without a grammar, we need to extract the answer from the free-form response
+            matches = re.findall(
+                # regex that matches the answer after an optional "Response: " prefix
+                r"(?:Response:\s?)?(.+)",
+                response.splitlines()[-1],
+                flags=re.IGNORECASE,
+            )
+            # use the last match if found, otherwise keep the whole response as the answer
+            if matches:
+                answer = matches[-1]
 
         result = self.metric.compute(
             predictions=[{"prediction_text": answer, "id": _id}],
@@ -140,7 +163,9 @@ class QuestionAnswering(Task):
     @property
     def base_prompts(self):
         # TODO find good base prompts
-        return ["""In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""]
+        return [
+            """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
+        ]
 
 
 class SQuAD(QuestionAnswering):
@@ -165,3 +190,6 @@ class SQuAD(QuestionAnswering):
 
     def _get_gold_label_for_datum(self, datum: DatasetDatum):
         return datum["answers"]
+
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
+        return self._get_gold_label_for_datum(datum)["text"][0]
diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py
index d3697c2..6886ed1 100644
--- a/evoprompt/task/sentiment_analysis.py
+++ b/evoprompt/task/sentiment_analysis.py
@@ -1,16 +1,11 @@
-import json
 import logging
-from abc import abstractmethod
-from argparse import Namespace
-from functools import lru_cache
-from pathlib import Path
+from functools import cache
 from typing import Mapping
 
 from datasets import load_dataset
-from llama_cpp import LlamaGrammar
 
 from evoprompt.helpers.prompts import BasePromptsFromJsonMixin
-from evoprompt.task import Task, TextClassification
+from evoprompt.task import TextClassification
 from evoprompt.task.task import DatasetDatum
 
 logger = logging.getLogger(__name__)
@@ -22,13 +17,14 @@ logger = logging.getLogger(__name__)
 
 
 class SentimentAnalysis(TextClassification):
-    @lru_cache
-    def _get_label_mapping(self) -> Mapping:
+    @staticmethod
+    @cache
+    def _get_label_mapping() -> Mapping:
         return {"negative": 0, "positive": 1}
 
 
 class HfSST2(BasePromptsFromJsonMixin, SentimentAnalysis):
-    shorthand = "hf-sst2"
+    shorthand = "sst2-hf"
     base_prompts_file = "evoprompt/initial_prompts/sst-2/prompts.json"
 
     def load_validation_set(
@@ -111,14 +107,15 @@ class SST5(BasePromptsFromJsonMixin, SentimentAnalysis):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         return datum["label"]
 
-    @lru_cache
-    def _get_label_mapping(self) -> Mapping:
+    @staticmethod
+    @cache
+    def _get_label_mapping() -> Mapping:
         classes = ["terrible", "bad", "okay", "good", "great"]
         return dict(zip(classes, range(len(classes))))
 
 
 class HfMovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis):
-    shorthand = "hf-mr"
+    shorthand = "mr-hf"
     base_prompts_file = "evoprompt/initial_prompts/mr/prompts.json"
 
     def load_validation_set(
diff --git a/evoprompt/task/simplification.py b/evoprompt/task/simplification.py
index 8295db7..4d58aeb 100644
--- a/evoprompt/task/simplification.py
+++ b/evoprompt/task/simplification.py
@@ -1,52 +1,27 @@
-import json
 import logging
-from abc import abstractmethod
-from functools import lru_cache
-from pathlib import Path
-from typing import Mapping
 
 from evaluate import load as load_metric
-from llama_cpp import LlamaGrammar
 
 from evoprompt.helpers.prompts import BasePromptsFromJsonMixin
-from evoprompt.models import LLMModel
-from evoprompt.task import Task
+from evoprompt.task import TextGeneration
 from evoprompt.task.task import DatasetDatum
 
 logger = logging.getLogger(__name__)
 
 
-class Simplification(Task):
+class Simplification(TextGeneration):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
 
         self.metric = load_metric("evaluate-metric/sari")
 
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+    def compute_metric(self, datum: DatasetDatum, prediction: str) -> float:
         gold_label = self._get_gold_label_for_datum(datum)
-        response, usage = self.predict(prompt=prompt, datum=datum)
-        response = response.lower()
-
-        scores = self.metric.compute(
+        return self.metric.compute(
             sources=[self._get_text_for_datum(datum)],
-            predictions=[response],
+            predictions=[prediction],
             references=[gold_label],
-        )
-        return scores["sari"], usage
-
-    @lru_cache
-    def _get_grammar(self, datum: DatasetDatum, verbose: bool = False):
-        return None
-
-    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
-        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
-
-    @abstractmethod
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
-        pass
-
-    def _aggregate_result(self, results: list[str]) -> float:
-        return sum(results) / len(results)
+        )["sari"]
 
     @property
     def metric_name(self):
diff --git a/evoprompt/task/subjectivity_classification.py b/evoprompt/task/subjectivity_classification.py
index 63988c6..5ded8da 100644
--- a/evoprompt/task/subjectivity_classification.py
+++ b/evoprompt/task/subjectivity_classification.py
@@ -1,6 +1,4 @@
-import json
-from functools import lru_cache
-from pathlib import Path
+from functools import cache
 from typing import Mapping
 
 from datasets import load_dataset
@@ -39,7 +37,8 @@ class Subj(BasePromptsFromJsonMixin, TextClassification):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         return datum["label"]
 
-    @lru_cache
-    def _get_label_mapping(self) -> Mapping:
+    @staticmethod
+    @cache
+    def _get_label_mapping() -> Mapping:
         classes = ["subjective", "objective"]
         return dict(zip(classes, range(len(classes))))
diff --git a/evoprompt/task/summarization.py b/evoprompt/task/summarization.py
index d199937..3d45aef 100644
--- a/evoprompt/task/summarization.py
+++ b/evoprompt/task/summarization.py
@@ -1,53 +1,25 @@
-import json
 import logging
-from abc import abstractmethod
-from functools import lru_cache
-from pathlib import Path
-from typing import Mapping
 
 from evaluate import load as load_metric
-from llama_cpp import LlamaGrammar
 
 from evoprompt.helpers.prompts import BasePromptsFromJsonMixin
-from evoprompt.models import LLMModel
-from evoprompt.task import Task
+from evoprompt.task import TextGeneration
 from evoprompt.task.task import DatasetDatum
 
 logger = logging.getLogger(__name__)
 
 
-class Summarization(Task):
+class Summarization(TextGeneration):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
 
         self.metric = load_metric("evaluate-metric/rouge")
 
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+    def compute_metric(self, datum: DatasetDatum, prediction: str) -> float:
         gold_label = self._get_gold_label_for_datum(datum)
-        response, usage = self.predict(prompt=prompt, datum=datum)
-        response = response.lower()
-
-        scores = self.metric.compute(predictions=[response], references=[gold_label])
-
-        return scores["rougeL"], usage
-
-    @lru_cache
-    def _get_grammar(self, datum: DatasetDatum, verbose: bool = False):
-        return None
-
-    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
-        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
-
-    @abstractmethod
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
-        pass
-
-    @abstractmethod
-    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
-        pass
-
-    def _aggregate_result(self, results: list[str]) -> float:
-        return sum(results) / len(results)
+        return self.metric.compute(predictions=[prediction], references=[gold_label])[
+            "rougeL"
+        ]
 
     @property
     def metric_name(self):
diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index 5bb2f3c..fc91fc5 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -3,7 +3,7 @@ from abc import ABCMeta, abstractmethod
 from collections import deque
 from dataclasses import KW_ONLY, dataclass
 from statistics import mean
-from typing import Any, Iterable, Literal
+from typing import Iterable, Literal
 
 from datasets import Dataset, load_dataset
 from llama_cpp import LlamaGrammar
@@ -334,7 +334,7 @@ class Task(metaclass=ABCMeta):
     @abstractmethod
     def _get_demonstration_sample_ids(
         self, dataset: Dataset, n_evaluation_demo: int
-    ) -> list[Any]:
+    ) -> Iterable[int]:
         pass
 
     def predict(self, prompt: str, datum: DatasetDatum) -> tuple[str, ModelUsage]:
@@ -342,6 +342,7 @@ class Task(metaclass=ABCMeta):
         # TODO grammar also depends on prompt and vice-versa -> what are good labels?
         response, _, usage = self.model(
             system_message=SYSTEM_MESSAGE,
+            # TODO Allow to modify prompt construction in subclasses
             prompt=prompt,
             prompt_suffix="\n" + self._get_prompt_text_for_datum(datum),
             prompt_appendix="\nResponse: ",
@@ -375,7 +376,7 @@ class Task(metaclass=ABCMeta):
 
     @abstractmethod
     # This method is needed for the demonstration examples.
-    def _get_prompt_output_for_datum(self, datum: DatasetDatum) -> str:
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
         pass
 
     @abstractmethod
@@ -399,7 +400,7 @@ class Task(metaclass=ABCMeta):
         # augment prompt with demonstration samples
         prompt += "".join(
             [
-                f"\n{self._get_prompt_text_for_datum(datum)}\nResponse: {self._get_prompt_output_for_datum(datum)}"
+                f"\n{self._get_prompt_text_for_datum(datum)}\nResponse: {self._get_gold_label_generation_for_datum(datum)}"
                 for datum in self.demonstration_samples
             ]
         )
diff --git a/evoprompt/task/text_classification.py b/evoprompt/task/text_classification.py
index d1e1016..ea7f382 100644
--- a/evoprompt/task/text_classification.py
+++ b/evoprompt/task/text_classification.py
@@ -1,14 +1,15 @@
-from abc import abstractmethod
 import logging
 import re
+from abc import abstractmethod
+from functools import lru_cache
 from typing import Any, Mapping
 
 from datasets import Dataset
 from llama_cpp import LlamaGrammar
+
 from evoprompt.task import Task
 from evoprompt.task.task import DatasetDatum
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -17,12 +18,13 @@ class TextClassification(Task):
         gold_label = self._get_gold_label_for_datum(datum)
         class_mapping = self._get_label_mapping()
         response, usage = self.predict(prompt=prompt, datum=datum)
+        response = response.lower()
         if self.use_grammar:
             # model output is from label space
             answer_label = class_mapping[response]
         else:
             matches = re.findall(
-                # regex that matches "negative" or "positive" after "Response: "
+                # regex that matches class labels after "Response: "
                 rf"Response: ({'|'.join(class_mapping.keys())})",
                 response,
                 flags=re.IGNORECASE,
@@ -30,9 +32,18 @@ class TextClassification(Task):
             if matches:
                 answer_label = class_mapping[matches[-1]]
             else:
-                # TODO in this case we could try other stuff, like checking if a class label is somewhere in the response?
-                logger.warning(f"Invalid answer: {response}")
-                return "failed", usage
+                # look for a label in the response, if not found, return failed
+                matches = re.findall(
+                    # regex that matches class labels anywhere in the response
+                    rf"({'|'.join(class_mapping.keys())})",
+                    response,
+                    flags=re.IGNORECASE,
+                )
+                if matches:
+                    answer_label = class_mapping[matches[-1]]
+                else:
+                    logger.warning(f"Invalid answer: {response}")
+                    return "failed", usage
 
         classification_result = "incorrect" if answer_label != gold_label else "correct"
         return classification_result, usage
@@ -77,8 +88,12 @@ class TextClassification(Task):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         pass
 
-    def _get_prompt_output_for_datum(self, datum: DatasetDatum) -> str:
-        id_to_label = {v: k for k, v in self._get_label_mapping().items()}
+    @lru_cache
+    def _get_inverse_label_mapping(self):
+        return {v: k for k, v in self._get_label_mapping().items()}
+
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
+        id_to_label = self._get_inverse_label_mapping()
         return id_to_label[self._get_gold_label_for_datum(datum)]
 
     def _aggregate_result(self, results: list[str]) -> float:
diff --git a/evoprompt/task/text_generation.py b/evoprompt/task/text_generation.py
new file mode 100644
index 0000000..7e20cd6
--- /dev/null
+++ b/evoprompt/task/text_generation.py
@@ -0,0 +1,43 @@
+import logging
+from abc import abstractmethod
+from typing import Iterable
+
+from datasets import Dataset
+from llama_cpp import LlamaGrammar
+
+from evoprompt.task import Task
+from evoprompt.task.task import DatasetDatum
+from evoprompt.utils import get_rng
+
+logger = logging.getLogger(__name__)
+
+
+class TextGeneration(Task):
+    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        response, usage = self.predict(prompt=prompt, datum=datum)
+        response = response.lower()
+        return self.compute_metric(datum, response), usage
+
+    def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar:
+        # there is no grammar for open text generation
+        return None
+
+    def _get_demonstration_sample_ids(
+        self, dataset: Dataset, n_evaluation_demo: int
+    ) -> Iterable[int]:
+        # select demonstration samples uniformly at random
+        return get_rng().choice(len(dataset), n_evaluation_demo, replace=False)
+
+    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
+        # TODO do we need quotes?
+        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
+
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
+        return self._get_gold_label_for_datum(datum)
+
+    @abstractmethod
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        pass
+
+    def _aggregate_result(self, results: list[float]) -> float:
+        return sum(results) / len(results)
diff --git a/evoprompt/task/topic_classification.py b/evoprompt/task/topic_classification.py
index aecec3e..4103a43 100644
--- a/evoprompt/task/topic_classification.py
+++ b/evoprompt/task/topic_classification.py
@@ -1,12 +1,9 @@
-import json
-from functools import lru_cache
-from pathlib import Path
+from functools import cache, lru_cache
 from typing import Mapping
 
 from datasets import load_dataset
 
 from evoprompt.helpers.prompts import BasePromptsFromJsonMixin
-from evoprompt.models import LLMModel
 from evoprompt.task import TextClassification
 from evoprompt.task.task import DatasetDatum
 
@@ -31,8 +28,9 @@ class AGNews(BasePromptsFromJsonMixin, TextClassification):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         return datum["label"]
 
-    @lru_cache
-    def _get_label_mapping(self) -> Mapping:
+    @staticmethod
+    @cache
+    def _get_label_mapping() -> Mapping:
         classes = ["world", "sports", "business", "tech"]
         return dict(zip(classes, range(len(classes))))
 
-- 
GitLab