From caf6df98625327c32203df21d159ee895dfa75da Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Tue, 30 Jul 2024 18:01:28 +0200
Subject: [PATCH] Update task structure

Rework the task classes so that predict() and _evaluate_sample() receive
the full DatasetDatum and read individual fields through new abstract
per-datum accessors instead of loose question/context/text arguments:
QuestionAnswering declares _get_id_from_datum, _get_context_from_datum,
_get_question_from_datum and _get_gold_labels_from_datum (implemented by
SQuAD), SentimentAnalysis declares _get_text_for_datum and
_get_gold_label_for_datum (implemented by SST2), and the Task base class
now requires _get_prompt_text_for_datum instead of _get_text_for_datum.
SST2 also gets its own shorthand ("sst2") and switches from SetFit/sst2
to stanfordnlp/sst2.
---
 evoprompt/task/question_answering.py | 56 +++++++++++++++++++++-------
 evoprompt/task/sentiment_analysis.py | 43 ++++++++++++---------
 evoprompt/task/task.py               |  6 ++-
 3 files changed, 72 insertions(+), 33 deletions(-)
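
Reviewer note, not part of the applied patch: a minimal sketch of how a
further QA dataset could plug into the restructured QuestionAnswering
hooks below. The class name, shorthand and field names ("qid", "passage",
"query", "answers") are invented for illustration, and the DatasetDatum
import path is assumed; only the abstract method names come from the diff.

# Illustrative only. Implements the four per-datum accessors that
# QuestionAnswering now declares as abstract; the dataset-loading
# overrides (cf. SQuAD.load_test_set below) are omitted for brevity.
from evoprompt.task.question_answering import QuestionAnswering
from evoprompt.task.task import DatasetDatum  # assumed import location


class HypotheticalQADataset(QuestionAnswering):
    shorthand = "hqa"  # invented shorthand

    def _get_id_from_datum(self, datum: DatasetDatum):
        return datum["qid"]

    def _get_context_from_datum(self, datum: DatasetDatum):
        return datum["passage"]

    def _get_question_from_datum(self, datum: DatasetDatum):
        return datum["query"]

    def _get_gold_labels_from_datum(self, datum: DatasetDatum):
        # the squad metric expects {"text": [...], "answer_start": [...]}
        return datum["answers"]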

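Similarly, and also not part of the patch: a sketch of a hypothetical
sentiment dataset against the reworked SentimentAnalysis hooks. The
dataset id "org/hypothetical-reviews" and the "review"/"sentiment"
columns are placeholders, not a real schema.

# Illustrative only. Mirrors what SST2 does in the diff: point the
# loaders at a dataset and implement the two new accessors. Note that
# _evaluate_sample still maps model output through the hard-coded
# sst2_labels dict, so the gold labels here are assumed to use the same
# 0/1 negative/positive ids.
from evoprompt.task.sentiment_analysis import SentimentAnalysis
from evoprompt.task.task import DatasetDatum  # assumed import location


class HypotheticalReviews(SentimentAnalysis):
    shorthand = "hrev"  # invented shorthand

    def load_validation_set(
        self, validation_dataset: str | None, validation_split: str | None
    ):
        return super().load_validation_set(
            "org/hypothetical-reviews", "validation[:200]"
        )

    def load_test_set(self, test_dataset: str, test_split: str | None):
        return super().load_test_set("org/hypothetical-reviews", "test")

    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
        return datum["review"]

    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
        return datum["sentiment"]
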
diff --git a/evoprompt/task/question_answering.py b/evoprompt/task/question_answering.py
index 25c9aac..853b6b4 100644
--- a/evoprompt/task/question_answering.py
+++ b/evoprompt/task/question_answering.py
@@ -1,3 +1,4 @@
+from abc import abstractmethod
 import logging
 import re
 from argparse import Namespace
@@ -52,12 +53,12 @@ class QuestionAnswering(Task):
 
         self.metric = load_metric("squad")
 
-    def predict(self, prompt: str, question: str, context: str):
+    def predict(self, prompt: str, datum: DatasetDatum):
         # run model for inference
         grammar = None
         if self.use_grammar:
             # context-sensitive grammar
-            context = context
+            context = self._get_context_from_datum(datum)
             try:
                 grammar = extractive_qa_grammar_fn(context)
             except Exception as e:
@@ -71,7 +72,7 @@ class QuestionAnswering(Task):
         response, usage = self.model(
             system_message=SYSTEM_MESSAGE,
             prompt=prompt,
-            prompt_appendix=self._get_text_for_datum(question, context),
+            prompt_appendix=self._get_prompt_text_for_datum(datum),
             grammar=grammar,
         )
 
@@ -81,21 +82,41 @@ class QuestionAnswering(Task):
 
         return response, usage
 
-    def _get_text_for_datum(self, question: str, context: str) -> str:
+    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
+        context = self._get_context_from_datum(datum)
+        question = self._get_question_from_datum(datum)
         return (
             "\nContext: " + '"' + context + '"' + "\nQuestion: " + '"' + question + '"'
         )
 
-    def _evaluate_qa_sample(
-        self, prompt: str, id: str, question, context, gold_answers: list[str]
-    ):
-        answer, usage = self.predict(prompt, question, context)
+    @abstractmethod
+    def _get_id_from_datum(self, datum: DatasetDatum):
+        pass
+
+    @abstractmethod
+    def _get_context_from_datum(self, datum: DatasetDatum):
+        pass
+
+    @abstractmethod
+    def _get_question_from_datum(self, datum: DatasetDatum):
+        pass
+
+    @abstractmethod
+    def _get_gold_labels_from_datum(self, datum: DatasetDatum):
+        pass
+
+    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        _id = self._get_id_from_datum(datum)
+        context = self._get_context_from_datum(datum)
+        question = self._get_question_from_datum(datum)
+        gold_answers = self._get_gold_labels_from_datum(datum)
+        answer, usage = self.predict(prompt, datum)
         # input(f'*** PREDICTION ***\n\tContext: {datum["context"]}\n\tQuestion: {datum["question"]}\n\tAnswer: {answer}\n\tGold answer: {datum["answers"]["text"][0]}')
         # TODO check if answer is lower-cased in metric computation
 
         result = self.metric.compute(
-            predictions=[{"prediction_text": answer, "id": id}],
-            references=[{"answers": gold_answers, "id": id}],
+            predictions=[{"prediction_text": answer, "id": _id}],
+            references=[{"answers": gold_answers, "id": _id}],
         )
         return result["f1"] / 100, usage
 
@@ -154,7 +175,14 @@ class SQuAD(QuestionAnswering):
     def load_test_set(self, test_dataset: str, test_split: str | None):
         return super().load_test_set("squad", "validation")
 
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
-        return self._evaluate_qa_sample(
-            prompt, datum["id"], datum["question"], datum["context"], datum["answers"]
-        )
+    def _get_context_from_datum(self, datum: DatasetDatum):
+        return datum["context"]
+
+    def _get_question_from_datum(self, datum: DatasetDatum):
+        return datum["question"]
+
+    def _get_id_from_datum(self, datum: DatasetDatum):
+        return datum["id"]
+
+    def _get_gold_labels_from_datum(self, datum: DatasetDatum):
+        return datum["answers"]
diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py
index a31a4fc..9e1586c 100644
--- a/evoprompt/task/sentiment_analysis.py
+++ b/evoprompt/task/sentiment_analysis.py
@@ -1,3 +1,4 @@
+from abc import abstractmethod
 import logging
 from argparse import Namespace
 from functools import lru_cache
@@ -23,15 +24,13 @@ def sa_grammar_fn(verbose: bool = False):
 
 
 class SentimentAnalysis(Task):
-    shorthand = "sa"
-
-    def predict(self, prompt: str, text: str):
+    def predict(self, prompt: str, datum: DatasetDatum):
         # run model for inference using grammar to constrain output
         # TODO grammar also depends on prompt and vice-versa -> what are good labels?
         response, usage = self.model(
             system_message=SYSTEM_MESSAGE,
             prompt=prompt,
-            prompt_appendix="\nInput: " + '"' + text + '"',
+            prompt_appendix=self._get_prompt_text_for_datum(datum),
             grammar=sa_grammar_fn() if self.use_grammar else None,
         )
 
@@ -42,15 +41,10 @@ class SentimentAnalysis(Task):
 
         return response, usage
 
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
-        return datum["text"]
-
     def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        gold_label = self._get_gold_label_for_datum(datum)
         sst2_labels = {"negative": 0, "positive": 1}
-
-        response, usage = self.predict(
-            prompt=prompt, text=self._get_text_for_datum(datum)
-        )
+        response, usage = self.predict(prompt=prompt, datum=datum)
         response = response.lower()
         if self.use_grammar:
             # model output is from label space
@@ -65,10 +59,19 @@ class SentimentAnalysis(Task):
                 logger.warning(f"Invalid answer: {response}")
                 return "failed", usage
 
-        classification_result = (
-            "incorrect" if answer_label != datum["label"] else "correct"
-        )
+        classification_result = "incorrect" if answer_label != gold_label else "correct"
         return classification_result, usage
+
+    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
+        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
+
+    @abstractmethod
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        pass
+
+    @abstractmethod
+    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+        pass
 
     def _aggregate_result(self, results: list[str]) -> float:
         num_correct_results = sum(1 for result in results if result == "correct")
@@ -86,12 +89,18 @@ class SentimentAnalysis(Task):
 
 
 class SST2(SentimentAnalysis):
-    shorthand = "sa"
+    shorthand = "sst2"
 
     def load_validation_set(
         self, validation_dataset: str | None, validation_split: str | None
     ):
-        return super().load_validation_set("SetFit/sst2", "validation[:200]")
+        return super().load_validation_set("stanfordnlp/sst2", "validation[:200]")
 
     def load_test_set(self, test_dataset: str, test_split: str | None):
-        return super().load_test_set("SetFit/sst2", "test")
+        return super().load_test_set("stanfordnlp/sst2", "test")
+
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        return datum["sentence"]
+
+    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+        return datum["label"]
diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index 40c3bd0..234dc25 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -171,7 +171,7 @@ class Task(metaclass=ABCMeta):
         pass
 
     @abstractmethod
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
         pass
 
     @abstractmethod
@@ -199,7 +199,9 @@ class Task(metaclass=ABCMeta):
 
         results: list = []
         if self.evaluate_shortest_first:
-            dataset = sorted(dataset, key=lambda x: len(self._get_text_for_datum(x)))
+            dataset = sorted(
+                dataset, key=lambda x: len(self._get_prompt_text_for_datum(x))
+            )
         dataset_iterator: tqdm[DatasetDatum] = tqdm(
             dataset, desc="evaluating prompt", leave=False
         )
-- 
GitLab