From caf6df98625327c32203df21d159ee895dfa75da Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Tue, 30 Jul 2024 18:01:28 +0200
Subject: [PATCH] Update task structure

---
 evoprompt/task/question_answering.py | 56 +++++++++++++++++++++-------
 evoprompt/task/sentiment_analysis.py | 43 ++++++++++++---------
 evoprompt/task/task.py               |  6 ++-
 3 files changed, 72 insertions(+), 33 deletions(-)

diff --git a/evoprompt/task/question_answering.py b/evoprompt/task/question_answering.py
index 25c9aac..853b6b4 100644
--- a/evoprompt/task/question_answering.py
+++ b/evoprompt/task/question_answering.py
@@ -1,3 +1,4 @@
+from abc import abstractmethod
 import logging
 import re
 from argparse import Namespace
@@ -52,12 +53,12 @@ class QuestionAnswering(Task):
 
         self.metric = load_metric("squad")
 
-    def predict(self, prompt: str, question: str, context: str):
+    def predict(self, prompt: str, datum: DatasetDatum):
         # run model for inference
         grammar = None
         if self.use_grammar:
             # context-sensitive grammar
-            context = context
+            context = self._get_context_from_datum(datum)
             try:
                 grammar = extractive_qa_grammar_fn(context)
             except Exception as e:
@@ -71,7 +72,7 @@ class QuestionAnswering(Task):
         response, usage = self.model(
             system_message=SYSTEM_MESSAGE,
             prompt=prompt,
-            prompt_appendix=self._get_text_for_datum(question, context),
+            prompt_appendix=self._get_prompt_text_for_datum(datum),
             grammar=grammar,
         )
 
@@ -81,21 +82,41 @@ class QuestionAnswering(Task):
 
         return response, usage
 
-    def _get_text_for_datum(self, question: str, context: str) -> str:
+    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
+        context = self._get_context_from_datum(datum)
+        question = self._get_question_from_datum(datum)
         return (
             "\nContext: " + '"' + context + '"' + "\nQuestion: " + '"' + question + '"'
         )
 
-    def _evaluate_qa_sample(
-        self, prompt: str, id: str, question, context, gold_answers: list[str]
-    ):
-        answer, usage = self.predict(prompt, question, context)
+    @abstractmethod
+    def _get_id_from_datum(self, datum: DatasetDatum):
+        pass
+
+    @abstractmethod
+    def _get_context_from_datum(self, datum: DatasetDatum):
+        pass
+
+    @abstractmethod
+    def _get_question_from_datum(self, datum: DatasetDatum):
+        pass
+
+    @abstractmethod
+    def _get_gold_labels_from_datum(self, datum: DatasetDatum):
+        pass
+
+    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        _id = self._get_id_from_datum(datum)
+        context = self._get_context_from_datum(datum)
+        question = self._get_question_from_datum(datum)
+        gold_answers = self._get_gold_labels_from_datum(datum)
+        answer, usage = self.predict(prompt, datum)
         # input(f'*** PREDICTION ***\n\tContext: {datum["context"]}\n\tQuestion: {datum["question"]}\n\tAnswer: {answer}\n\tGold answer: {datum["answers"]["text"][0]}')
         # TODO check if answer is lower-cased in metric computation
         result = self.metric.compute(
-            predictions=[{"prediction_text": answer, "id": id}],
-            references=[{"answers": gold_answers, "id": id}],
+            predictions=[{"prediction_text": answer, "id": _id}],
+            references=[{"answers": gold_answers, "id": _id}],
         )
         return result["f1"] / 100, usage
 
@@ -154,7 +175,14 @@ class SQuAD(QuestionAnswering):
     def load_test_set(self, test_dataset: str, test_split: str | None):
         return super().load_test_set("squad", "validation")
 
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
-        return self._evaluate_qa_sample(
-            prompt, datum["id"], datum["question"], datum["context"], datum["answers"]
-        )
+    def _get_context_from_datum(self, datum: DatasetDatum):
+        return datum["context"]
+
+    def _get_question_from_datum(self, datum: DatasetDatum):
+        return datum["question"]
+
+    def _get_id_from_datum(self, datum: DatasetDatum):
+        return datum["id"]
+
+    def _get_gold_labels_from_datum(self, datum: DatasetDatum):
+        return datum["answers"]
diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py
index a31a4fc..9e1586c 100644
--- a/evoprompt/task/sentiment_analysis.py
+++ b/evoprompt/task/sentiment_analysis.py
@@ -1,3 +1,4 @@
+from abc import abstractmethod
 import logging
 from argparse import Namespace
 from functools import lru_cache
@@ -23,15 +24,13 @@ def sa_grammar_fn(verbose: bool = False):
 
 
 class SentimentAnalysis(Task):
-    shorthand = "sa"
-
-    def predict(self, prompt: str, text: str):
+    def predict(self, prompt: str, datum: DatasetDatum):
         # run model for inference using grammar to constrain output
         # TODO grammar also depends on prompt and vice-versa -> what are good labels?
         response, usage = self.model(
             system_message=SYSTEM_MESSAGE,
             prompt=prompt,
-            prompt_appendix="\nInput: " + '"' + text + '"',
+            prompt_appendix=self._get_prompt_text_for_datum(datum),
             grammar=sa_grammar_fn() if self.use_grammar else None,
         )
 
@@ -42,15 +41,10 @@ class SentimentAnalysis(Task):
 
         return response, usage
 
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
-        return datum["text"]
-
     def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        gold_label = self._get_gold_label_for_datum(datum)
         sst2_labels = {"negative": 0, "positive": 1}
-
-        response, usage = self.predict(
-            prompt=prompt, text=self._get_text_for_datum(datum)
-        )
+        response, usage = self.predict(prompt=prompt, datum=datum)
         response = response.lower()
         if self.use_grammar:
             # model output is from label space
@@ -65,10 +59,19 @@ class SentimentAnalysis(Task):
                 logger.warning(f"Invalid answer: {response}")
                 return "failed", usage
 
-        classification_result = (
-            "incorrect" if answer_label != datum["label"] else "correct"
-        )
+        classification_result = "incorrect" if answer_label != gold_label else "correct"
         return classification_result, usage
+
+    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
+        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
+
+    @abstractmethod
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        pass
+
+    @abstractmethod
+    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+        pass
 
     def _aggregate_result(self, results: list[str]) -> float:
         num_correct_results = sum(1 for result in results if result == "correct")
@@ -86,12 +89,18 @@ class SentimentAnalysis(Task):
 
 
 class SST2(SentimentAnalysis):
-    shorthand = "sa"
+    shorthand = "sst2"
 
     def load_validation_set(
        self, validation_dataset: str | None, validation_split: str | None
    ):
-        return super().load_validation_set("SetFit/sst2", "validation[:200]")
+        return super().load_validation_set("stanfordnlp/sst2", "validation[:200]")
 
     def load_test_set(self, test_dataset: str, test_split: str | None):
-        return super().load_test_set("SetFit/sst2", "test")
+        return super().load_test_set("stanfordnlp/sst2", "test")
+
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        return datum["sentence"]
+
+    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+        return datum["label"]
diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index 40c3bd0..234dc25 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -171,7 +171,7 @@ class Task(metaclass=ABCMeta):
         pass
 
     @abstractmethod
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
         pass
 
     @abstractmethod
@@ -199,7 +199,9 @@ class Task(metaclass=ABCMeta):
         results: list = []
 
         if self.evaluate_shortest_first:
-            dataset = sorted(dataset, key=lambda x: len(self._get_text_for_datum(x)))
+            dataset = sorted(
+                dataset, key=lambda x: len(self._get_prompt_text_for_datum(x))
+            )
         dataset_iterator: tqdm[DatasetDatum] = tqdm(
             dataset, desc="evaluating prompt", leave=False
         )
-- 
GitLab
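
Note on the refactored structure: after this patch, a new dataset task only has to supply the per-datum accessors declared as abstract above; prompt construction, prediction, and evaluation are inherited from the base classes. The following is a minimal sketch of a hypothetical QA task written against the patched base class. The class name, shorthand, dataset identifier, splits, and import paths are assumptions for illustration only and are not part of this patch; the accessor and loader method names come from the diff above.

# Minimal sketch (not part of this patch): a hypothetical QA task using the
# refactored per-datum accessors. Import paths, class name, and dataset id are
# assumed; the method names mirror the abstract hooks introduced in the patch.
from evoprompt.task.question_answering import QuestionAnswering
from evoprompt.task.task import DatasetDatum  # assumed location of DatasetDatum


class MyQATask(QuestionAnswering):  # hypothetical subclass
    shorthand = "myqa"  # placeholder shorthand

    def load_validation_set(
        self, validation_dataset: str | None, validation_split: str | None
    ):
        # placeholder dataset identifier and split
        return super().load_validation_set("my-org/my-qa-dataset", "validation[:200]")

    def load_test_set(self, test_dataset: str, test_split: str | None):
        return super().load_test_set("my-org/my-qa-dataset", "test")

    # Only the per-datum accessors need to be implemented; predict() and
    # _evaluate_sample() are inherited from QuestionAnswering.
    def _get_id_from_datum(self, datum: DatasetDatum):
        return datum["id"]

    def _get_context_from_datum(self, datum: DatasetDatum):
        return datum["context"]

    def _get_question_from_datum(self, datum: DatasetDatum):
        return datum["question"]

    def _get_gold_labels_from_datum(self, datum: DatasetDatum):
        return datum["answers"]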