From c27d342bcd0f919e50f51870bfc8c0638a520671 Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Mon, 29 Jul 2024 17:53:30 +0200
Subject: [PATCH] Refactor tasks

---
 evoprompt/evolution.py               |   2 +-
 evoprompt/task.py                    | 487 ---------------------------
 evoprompt/task/__init__.py           |  33 ++
 evoprompt/task/question_answering.py | 160 +++++++++
 evoprompt/task/sentiment_analysis.py |  97 ++++++
 evoprompt/task/task.py               | 245 ++++++++++++++
 evoprompt/utils.py                   |   5 +
 main.py                              |   2 +-
 8 files changed, 542 insertions(+), 489 deletions(-)
 delete mode 100644 evoprompt/task.py
 create mode 100644 evoprompt/task/__init__.py
 create mode 100644 evoprompt/task/question_answering.py
 create mode 100644 evoprompt/task/sentiment_analysis.py
 create mode 100644 evoprompt/task/task.py

diff --git a/evoprompt/evolution.py b/evoprompt/evolution.py
index 9026d4d..e96e0ee 100644
--- a/evoprompt/evolution.py
+++ b/evoprompt/evolution.py
@@ -10,7 +10,7 @@ from evoprompt.models import LLMModel
 from evoprompt.opt_types import ModelUsage, Prompt
 from evoprompt.optimization import PromptOptimization
 from evoprompt.task import Task
-from evoprompt.utils import log_calls
+from evoprompt.utils import log_calls, get_all_subclasses
 
 logger = logging.getLogger(__name__)
 
diff --git a/evoprompt/task.py b/evoprompt/task.py
deleted file mode 100644
index d7367b5..0000000
--- a/evoprompt/task.py
+++ /dev/null
@@ -1,487 +0,0 @@
-import logging
-import re
-from abc import abstractmethod
-from argparse import Namespace
-from functools import lru_cache
-from statistics import mean
-from typing import Union
-
-from datasets import Dataset, load_dataset
-from evaluate import load as load_metric
-from llama_cpp import LlamaGrammar, deque
-from tqdm import tqdm
-
-from evoprompt.cli import argument_parser
-from evoprompt.models import LLMModel
-from evoprompt.opt_types import ModelUsage
-from evoprompt.utils import log_calls
-
-logger = logging.getLogger(__name__)
-
-
-SYSTEM_MESSAGE = """
-You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
-"""
-
-DatasetDatum = dict
-
-
-class EarlyStoppingMonitor:
-
-    @abstractmethod
-    def update(self, score: float) -> bool:
-        raise NotImplementedError
-
-
-class MomentBasedStopping(EarlyStoppingMonitor):
-    """
-    Watch the first derivative (moment) of the metric to determine when to stop.
-    """
-
-    def __init__(
-        self,
-        *,
-        patience: int = 10,
-        start_after: int = 20,
-        min_moment_magnitude: float = 0.001,
-    ):
-        self.patience = patience
-        self.start_after = start_after
-        self.min_moment_magnitude = min_moment_magnitude
-
-        self.moment_magnitudes = deque(maxlen=patience)
-        self.last_score = 0.0
-        self.num_calls = 0
-
-    def update(self, score: float) -> bool:
-        # caclulate the current moment (dx/dt)
-        self.num_calls += 1
-        if self.num_calls < self.start_after:
-            return False
-
-        self.moment_magnitudes.append(abs(score - self.last_score))
-        self.last_score = score
-        if len(self.moment_magnitudes) < self.patience:
-            return False
-
-        if mean(self.moment_magnitudes) < self.min_moment_magnitude:
-            return True
-
-        return False
-
-
-class ParentBaselineBasedStopping(EarlyStoppingMonitor):
-
-    def __init__(
-        self,
-        parent_histories: list[list[float]],
-        *,
-        patience: int = 10,
-        start_after: int = 20,
-        min_improvement: float = 0.001,
-    ):
-        self.parent_histories = parent_histories
-        self.patience = patience
-        self.start_after = start_after
-        self.min_improvement = min_improvement
-        self.num_calls = 0
-        self.improvement_memory = deque(maxlen=patience)
-
-    def update(self, score: float) -> bool:
-        self.num_calls += 1
-        if self.num_calls < self.start_after:
-            return False
-
-        parent_values = [  # get the metric value of the parents at the current step
-            (
-                parent_history[self.num_calls - 1]
-                if len(parent_history) >= self.num_calls
-                else parent_history[-1]  # extend with last value
-            )
-            for parent_history in self.parent_histories
-        ]
-        self.improvement_memory.append(
-            score - max(parent_values)  # compare with the best parent
-        )
-
-        if len(self.improvement_memory) < self.patience:
-            return False
-
-        if max(self.improvement_memory) < self.min_improvement:
-            # if the highest improvement is less than the minimum improvement, we stop
-            return True
-
-        return False
-
-
-class Task:
-    shorthand: str
-    validation_dataset: Dataset
-    test_dataset: Dataset
-
-    def __init__(
-        self,
-        model: Union[LLMModel],
-        validation_dataset: str,
-        test_dataset: str,
-        *,
-        use_grammar: bool,
-        no_early_stopping: bool = False,
-        evaluate_shortest_first: bool = False,
-        validation_split: str | None = None,
-        test_split: str | None = None,
-    ) -> None:
-        self.model = model
-
-        # whether we use the grammar to constrain the model output or not
-        self.use_grammar = use_grammar
-        self.no_early_stopping = no_early_stopping
-        self.evaluate_shortest_first = evaluate_shortest_first
-        if evaluate_shortest_first and no_early_stopping:
-            logger.warning(
-                "Both 'evaluate_shortest_first' and 'no_early_stopping' are set. This makes no sense"
-            )
-        self.validation_dataset = load_dataset(
-            validation_dataset, split=validation_split
-        )
-        self.test_dataset = load_dataset(test_dataset, split=test_split)
-
-    @abstractmethod
-    def predict(self, prompt: str, *args, **kwargs) -> tuple[str, ModelUsage]:
-        pass
-
-    @abstractmethod
-    def _evaluate_sample(
-        self, prompt: str, datum: DatasetDatum
-    ) -> tuple[str, ModelUsage]:
-        pass
-
-    @abstractmethod
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
-        pass
-
-    @abstractmethod
-    def _aggregate_result(self, results: list) -> float:
-        pass
-
-    def evaluate(
-        self,
-        prompt: str,
-        dataset: Dataset,
-        parent_histories: list[list[float]] | None = None,
-    ) -> tuple[float, ModelUsage, list[float]]:
-
-        early_stopping: EarlyStoppingMonitor
-        early_stopping_params = {
-            "patience": max(len(dataset) // 20, 5),
-            "start_after": max(len(dataset) // 5, 5),
-        }
-        if parent_histories is not None:
-            early_stopping = ParentBaselineBasedStopping(
-                parent_histories, **early_stopping_params
-            )
-        else:
-            early_stopping = MomentBasedStopping(**early_stopping_params)
-
-        results: list = []
-        if self.evaluate_shortest_first:
-            dataset = sorted(dataset, key=lambda x: len(self._get_text_for_datum(x)))
-        dataset_iterator: tqdm[DatasetDatum] = tqdm(
-            dataset, desc="evaluating prompt", leave=False
-        )
-        evaluation_usage = ModelUsage()
-        evaluation_history = []
-
-        for datum in dataset_iterator:
-            result, usage = self._evaluate_sample(prompt, datum)
-            results.append(result)
-            current_metric = self._aggregate_result(results)
-            dataset_iterator.set_postfix(
-                {self.metric_name: f"{current_metric*100:.1f}%"}
-            )
-            evaluation_usage += usage
-            evaluation_history.append(current_metric)
-            if not self.no_early_stopping and early_stopping.update(current_metric):
-                logger.info(
-                    f"Early stopping after {len(results)} samples with {self.metric_name} of {current_metric*100:.1f}%"
-                )
-                break
-        # input(f'F1 score: {current_metric:.2f}')
-
-        return self._aggregate_result(results), evaluation_usage, evaluation_history
-
-    @log_calls("Evaluating validation dataset")
-    def evaluate_validation(
-        self, prompt: str, parent_histories: list[list[float]] | None = None
-    ):
-        return self.evaluate(prompt, self.validation_dataset, parent_histories)
-
-    @log_calls("Evaluating test dataset")
-    def evaluate_test(self, prompt: str):
-        return self.evaluate(prompt, self.test_dataset)
-
-    @property
-    @abstractmethod
-    def metric_name(self) -> str:
-        pass
-
-    @property
-    @abstractmethod
-    def base_prompt(self) -> str:
-        pass
-
-
-# a simple grammar in GBNF notation which allows either positive or negative
-# NOTE: for some reason the program crashes if the grammar is an instance attribute of the task class
-SA_OUTPUTS = set(("positive", "negative"))
-
-
-@lru_cache
-def sa_grammar_fn(verbose: bool = False):
-    return LlamaGrammar.from_string(
-        "root ::= ({})".format("|".join(('"' + cons + '"') for cons in SA_OUTPUTS)),
-        verbose=verbose,
-    )
-
-
-class SentimentAnalysis(Task):
-    shorthand = "sa"
-
-    def __init__(self, model, options: Namespace):
-        super().__init__(
-            model,
-            validation_dataset="SetFit/sst2",
-            test_dataset="SetFit/sst2",
-            use_grammar=options.use_grammar,
-            no_early_stopping=options.no_early_stopping,
-            evaluate_shortest_first=options.evaluate_shortest_first,
-            validation_split=f"validation[:{5 if options.debug else 200}]",
-            test_split="test[:20]" if options.debug else "test",
-        )
-
-    def predict(self, prompt: str, text: str):
-        # run model for inference using grammar to constrain output
-        # TODO grammar also depends on prompt and vice-versa -> what are good labels?
-        response, usage = self.model(
-            system_message=SYSTEM_MESSAGE,
-            prompt=prompt,
-            prompt_appendix="\nInput: " + '"' + text + '"',
-            grammar=sa_grammar_fn() if self.use_grammar else None,
-        )
-
-        if not self.use_grammar:
-            # we postprocess the model output to return as answer
-            response = response.strip()
-        # input(f"*** PREDICTION***\n\tText: {text}\n\tSentiment: {response}")
-
-        return response, usage
-
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
-        return datum["text"]
-
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
-        sst2_labels = {"negative": 0, "positive": 1}
-
-        response, usage = self.predict(
-            prompt=prompt, text=self._get_text_for_datum(datum)
-        )
-        response = response.lower()
-        if self.use_grammar:
-            # model output is from label space
-            answer_label = sst2_labels[response]
-        else:
-            answer_label = None
-            for label in sst2_labels.keys():
-                if label in response:
-                    answer_label = sst2_labels[label]
-                    break
-            else:
-                logger.warning(f"Invalid answer: {response}")
-                return "failed", usage
-
-        classification_result = (
-            "incorrect" if answer_label != datum["label"] else "correct"
-        )
-        return classification_result, usage
-
-    def _aggregate_result(self, results: list[str]) -> float:
-        num_correct_results = sum(1 for result in results if result == "correct")
-        accuracy = num_correct_results / len(results)
-        return accuracy
-
-    @property
-    def metric_name(self):
-        return "accuracy"
-
-    @property
-    def base_prompt(self):
-        #  from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning
-        return """In this task, you are given sentences from movie reviews. The task is to classify a sentence as 'positive' if the sentiment of the sentence is positive or as 'negative' if the sentiment of the sentence is negative. Return label only without any other text."""
-
-
-def grammar_continuous_with_arbitrary_end(sequence: list[str], quote: str = '"'):
-    if len(sequence) == 0:
-        return ""
-
-    if quote is None:
-        quote = ""
-
-    escape_item = lambda text: text.replace('"', '\\"')
-
-    grammar = quote + escape_item(sequence[0]) + quote
-    for idx, item in enumerate(sequence[1:]):
-        new_item = "(" + quote + escape_item(item) + quote + ")?"
-        if idx == 0:
-            grammar += new_item
-        else:
-            grammar = grammar[: -2 * idx] + new_item + grammar[-2 * idx :]
-    return grammar
-
-
-# a grammar in GBNF notation which takes the context into account
-# NOTE: for some reason the program crashes if the grammar is an instance attribute of the task class
-@lru_cache
-def extractive_qa_grammar_fn(context: str, verbose: bool = False):
-    grammar_str = "(\n"
-    tokens = re.split(r"(\W+)", context.strip())[:-1]
-
-    for i in range(len(tokens)):
-        if i > 0:
-            grammar_str += "\n|\n"
-        grammar_str += "(" + grammar_continuous_with_arbitrary_end(tokens[i:]) + ")"
-    grammar_str += "\n)"
-    return LlamaGrammar.from_string(f"root ::= {grammar_str}", verbose=verbose)
-
-
-class QuestionAnswering(Task):
-    shorthand = "qa"
-
-    def __init__(self, model, options: Namespace):
-
-        self.metric = load_metric("squad")
-
-        super().__init__(
-            model,
-            "squad",
-            "squad",
-            use_grammar=options.use_grammar,
-            no_early_stopping=options.no_early_stopping,
-            evaluate_shortest_first=options.evaluate_shortest_first,
-            validation_split=f"train[:{5 if options.debug else 200}]",
-            test_split="validation[:20]" if options.debug else "validation",
-        )
-
-    def predict(self, prompt: str, datum: DatasetDatum):
-        # run model for inference
-        grammar = None
-        if self.use_grammar:
-            # context-sensitive grammar
-            context = datum["context"]
-            try:
-                grammar = extractive_qa_grammar_fn(context)
-            except Exception as e:
-                logger.exception(
-                    "Could not create grammar for context (potentially maximum recursion depth exceeded), therefore we do not use a grammar for this sample.\nContext (with length %d): %s",
-                    len(context),
-                    context,
-                    exc_info=e,
-                )
-
-        response, usage = self.model(
-            system_message=SYSTEM_MESSAGE,
-            prompt=prompt,
-            prompt_appendix=self._get_text_for_datum(datum),
-            grammar=grammar,
-        )
-
-        if not self.use_grammar:
-            # we postprocess the model output to return as answer
-            response = response.strip()
-
-        return response, usage
-
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
-        return (
-            "\nContext: "
-            + '"'
-            + datum["context"]
-            + '"'
-            + "\nQuestion: "
-            + '"'
-            + datum["question"]
-            + '"'
-        )
-
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
-        answer, usage = self.predict(prompt, datum)
-        # input(f'*** PREDICTION ***\n\tContext: {datum["context"]}\n\tQuestion: {datum["question"]}\n\tAnswer: {answer}\n\tGold answer: {datum["answers"]["text"][0]}')
-        # TODO check if answer is lower-cased in metric computation
-
-        result = self.metric.compute(
-            predictions=[{"prediction_text": answer, "id": datum["id"]}],
-            references=[{"answers": datum["answers"], "id": datum["id"]}],
-        )
-
-        return result["f1"] / 100, usage
-
-    def _aggregate_result(self, results: list[float]) -> float:
-        return sum(results) / len(results)
-
-    def evaluate(
-        self,
-        prompt: str,
-        dataset: Dataset,
-        parent_histories: list[list[float]] | None = None,
-    ):
-        def replace_symbol_for_grammar(sample: DatasetDatum):
-            symbol_replacement_mapping = {
-                "\u2013": "-",
-                "\u2014": "-",
-            }
-            symbol_replacement_mapping = dict(
-                (re.escape(k), v) for k, v in symbol_replacement_mapping.items()
-            )
-            symbol_replacement_pattern = re.compile(
-                "|".join(symbol_replacement_mapping.keys())
-            )
-            replace_fn = lambda text: symbol_replacement_pattern.sub(
-                lambda m: symbol_replacement_mapping[re.escape(m.group(0))], text
-            )
-            sample["context"] = replace_fn(sample["context"])
-            sample["answers"]["text"] = [
-                replace_fn(text) for text in sample["answers"]["text"]
-            ]
-            return sample
-
-        if self.use_grammar:
-            # NOTE: the LlamaGrammar has issues with symbol '–' therefore we replace all occurences with '-' (hyphen)
-            dataset = dataset.map(replace_symbol_for_grammar, desc="Replacing symbols")
-        return super().evaluate(prompt, dataset, parent_histories=parent_histories)
-
-    @property
-    def metric_name(self):
-        return "f1"
-
-    @property
-    def base_prompt(self):
-        # TODO find good prompt
-        return """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
-
-
-tasks = {task.shorthand: task for task in Task.__subclasses__()}
-
-
-def get_task(name: str, evaluation_model: LLMModel, options: Namespace):
-    if name not in tasks:
-        raise ValueError("Model %s does not exist", name)
-    return tasks[name](evaluation_model, options)
-
-
-argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
-argument_group = argument_parser.add_argument_group("Task arguments")
-argument_group.add_argument("--use-grammar", "-g", action="store_true")
-argument_group.add_argument("--no-early-stopping", action="store_true")
-argument_group.add_argument("--evaluate-shortest-first", action="store_true")
-argument_group.add_argument(
-    "--task", "-t", type=str, required=True, choices=tasks.keys()
-)
diff --git a/evoprompt/task/__init__.py b/evoprompt/task/__init__.py
new file mode 100644
index 0000000..f810994
--- /dev/null
+++ b/evoprompt/task/__init__.py
@@ -0,0 +1,33 @@
+from argparse import Namespace
+
+from evoprompt.cli import argument_parser
+from evoprompt.models import LLMModel
+from evoprompt.task.question_answering import QuestionAnswering
+from evoprompt.task.sentiment_analysis import SentimentAnalysis
+
+# make sure to run definitions of subclasses of Task first
+from evoprompt.task.task import Task
+from evoprompt.utils import get_all_subclasses
+
+# at this point, we assume that all subclasses of Task have been defined
+tasks = {
+    task.shorthand: task
+    for task in get_all_subclasses(Task)
+    if hasattr(task, "shorthand")
+}
+
+
+def get_task(name: str, evaluation_model: LLMModel, **options):
+    if name not in tasks:
+        raise ValueError("Model %s does not exist", name)
+    return tasks[name](evaluation_model, **options)
+
+
+argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
+argument_group = argument_parser.add_argument_group("Task arguments")
+argument_group.add_argument("--use-grammar", "-g", action="store_true")
+argument_group.add_argument("--no-early-stopping", action="store_true")
+argument_group.add_argument("--evaluate-shortest-first", action="store_true")
+argument_group.add_argument(
+    "--task", "-t", type=str, required=True, choices=tasks.keys()
+)
diff --git a/evoprompt/task/question_answering.py b/evoprompt/task/question_answering.py
new file mode 100644
index 0000000..25c9aac
--- /dev/null
+++ b/evoprompt/task/question_answering.py
@@ -0,0 +1,160 @@
+import logging
+import re
+from argparse import Namespace
+from functools import lru_cache
+
+from datasets import Dataset
+from evaluate import load as load_metric
+from llama_cpp import LlamaGrammar
+
+from evoprompt.task.task import SYSTEM_MESSAGE, DatasetDatum, Task
+
+logger = logging.getLogger(__name__)
+
+
+def grammar_continuous_with_arbitrary_end(sequence: list[str], quote: str = '"'):
+    if len(sequence) == 0:
+        return ""
+
+    if quote is None:
+        quote = ""
+
+    escape_item = lambda text: text.replace('"', '\\"')
+
+    grammar = quote + escape_item(sequence[0]) + quote
+    for idx, item in enumerate(sequence[1:]):
+        new_item = "(" + quote + escape_item(item) + quote + ")?"
+        if idx == 0:
+            grammar += new_item
+        else:
+            grammar = grammar[: -2 * idx] + new_item + grammar[-2 * idx :]
+    return grammar
+
+
+# a grammar in GBNF notation which takes the context into account
+# NOTE: for some reason the program crashes if the grammar is an instance attribute of the task class
+@lru_cache
+def extractive_qa_grammar_fn(context: str, verbose: bool = False):
+    grammar_str = "(\n"
+    tokens = re.split(r"(\W+)", context.strip())[:-1]
+
+    for i in range(len(tokens)):
+        if i > 0:
+            grammar_str += "\n|\n"
+        grammar_str += "(" + grammar_continuous_with_arbitrary_end(tokens[i:]) + ")"
+    grammar_str += "\n)"
+    return LlamaGrammar.from_string(f"root ::= {grammar_str}", verbose=verbose)
+
+
+class QuestionAnswering(Task):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.metric = load_metric("squad")
+
+    def predict(self, prompt: str, question: str, context: str):
+        # run model for inference
+        grammar = None
+        if self.use_grammar:
+            # context-sensitive grammar
+            context = context
+            try:
+                grammar = extractive_qa_grammar_fn(context)
+            except Exception as e:
+                logger.exception(
+                    "Could not create grammar for context (potentially maximum recursion depth exceeded), therefore we do not use a grammar for this sample.\nContext (with length %d): %s",
+                    len(context),
+                    context,
+                    exc_info=e,
+                )
+
+        response, usage = self.model(
+            system_message=SYSTEM_MESSAGE,
+            prompt=prompt,
+            prompt_appendix=self._get_text_for_datum(question, context),
+            grammar=grammar,
+        )
+
+        if not self.use_grammar:
+            # we postprocess the model output to return as answer
+            response = response.strip()
+
+        return response, usage
+
+    def _get_text_for_datum(self, question: str, context: str) -> str:
+        return (
+            "\nContext: " + '"' + context + '"' + "\nQuestion: " + '"' + question + '"'
+        )
+
+    def _evaluate_qa_sample(
+        self, prompt: str, id: str, question, context, gold_answers: list[str]
+    ):
+        answer, usage = self.predict(prompt, question, context)
+        # input(f'*** PREDICTION ***\n\tContext: {datum["context"]}\n\tQuestion: {datum["question"]}\n\tAnswer: {answer}\n\tGold answer: {datum["answers"]["text"][0]}')
+        # TODO check if answer is lower-cased in metric computation
+
+        result = self.metric.compute(
+            predictions=[{"prediction_text": answer, "id": id}],
+            references=[{"answers": gold_answers, "id": id}],
+        )
+        return result["f1"] / 100, usage
+
+    def _aggregate_result(self, results: list[float]) -> float:
+        return sum(results) / len(results)
+
+    def evaluate(
+        self,
+        prompt: str,
+        dataset: Dataset,
+        parent_histories: list[list[float]] | None = None,
+    ):
+        def replace_symbol_for_grammar(sample: DatasetDatum):
+            symbol_replacement_mapping = {
+                "\u2013": "-",
+                "\u2014": "-",
+            }
+            symbol_replacement_mapping = dict(
+                (re.escape(k), v) for k, v in symbol_replacement_mapping.items()
+            )
+            symbol_replacement_pattern = re.compile(
+                "|".join(symbol_replacement_mapping.keys())
+            )
+            replace_fn = lambda text: symbol_replacement_pattern.sub(
+                lambda m: symbol_replacement_mapping[re.escape(m.group(0))], text
+            )
+            sample["context"] = replace_fn(sample["context"])
+            sample["answers"]["text"] = [
+                replace_fn(text) for text in sample["answers"]["text"]
+            ]
+            return sample
+
+        if self.use_grammar:
+            # NOTE: the LlamaGrammar has issues with symbol '–' therefore we replace all occurences with '-' (hyphen)
+            dataset = dataset.map(replace_symbol_for_grammar, desc="Replacing symbols")
+        return super().evaluate(prompt, dataset, parent_histories=parent_histories)
+
+    @property
+    def metric_name(self):
+        return "f1"
+
+    @property
+    def base_prompt(self):
+        # TODO find good prompt
+        return """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
+
+
+class SQuAD(QuestionAnswering):
+    shorthand = "squad"
+
+    def load_validation_set(
+        self, validation_dataset: str | None, validation_split: str | None
+    ):
+        return super().load_validation_set("squad", "train[:200]")
+
+    def load_test_set(self, test_dataset: str, test_split: str | None):
+        return super().load_test_set("squad", "validation")
+
+    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        return self._evaluate_qa_sample(
+            prompt, datum["id"], datum["question"], datum["context"], datum["answers"]
+        )
diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py
new file mode 100644
index 0000000..a31a4fc
--- /dev/null
+++ b/evoprompt/task/sentiment_analysis.py
@@ -0,0 +1,97 @@
+import logging
+from argparse import Namespace
+from functools import lru_cache
+
+from llama_cpp import LlamaGrammar
+
+from evoprompt.task.task import SYSTEM_MESSAGE, DatasetDatum, Task
+
+logger = logging.getLogger(__name__)
+
+
+# a simple grammar in GBNF notation which allows either positive or negative
+# NOTE: for some reason the program crashes if the grammar is an instance attribute of the task class
+SA_OUTPUTS = set(("positive", "negative"))
+
+
+@lru_cache
+def sa_grammar_fn(verbose: bool = False):
+    return LlamaGrammar.from_string(
+        "root ::= ({})".format("|".join(('"' + cons + '"') for cons in SA_OUTPUTS)),
+        verbose=verbose,
+    )
+
+
+class SentimentAnalysis(Task):
+    shorthand = "sa"
+
+    def predict(self, prompt: str, text: str):
+        # run model for inference using grammar to constrain output
+        # TODO grammar also depends on prompt and vice-versa -> what are good labels?
+        response, usage = self.model(
+            system_message=SYSTEM_MESSAGE,
+            prompt=prompt,
+            prompt_appendix="\nInput: " + '"' + text + '"',
+            grammar=sa_grammar_fn() if self.use_grammar else None,
+        )
+
+        if not self.use_grammar:
+            # we postprocess the model output to return as answer
+            response = response.strip()
+        # input(f"*** PREDICTION***\n\tText: {text}\n\tSentiment: {response}")
+
+        return response, usage
+
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        return datum["text"]
+
+    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        sst2_labels = {"negative": 0, "positive": 1}
+
+        response, usage = self.predict(
+            prompt=prompt, text=self._get_text_for_datum(datum)
+        )
+        response = response.lower()
+        if self.use_grammar:
+            # model output is from label space
+            answer_label = sst2_labels[response]
+        else:
+            answer_label = None
+            for label in sst2_labels.keys():
+                if label in response:
+                    answer_label = sst2_labels[label]
+                    break
+            else:
+                logger.warning(f"Invalid answer: {response}")
+                return "failed", usage
+
+        classification_result = (
+            "incorrect" if answer_label != datum["label"] else "correct"
+        )
+        return classification_result, usage
+
+    def _aggregate_result(self, results: list[str]) -> float:
+        num_correct_results = sum(1 for result in results if result == "correct")
+        accuracy = num_correct_results / len(results)
+        return accuracy
+
+    @property
+    def metric_name(self):
+        return "accuracy"
+
+    @property
+    def base_prompt(self):
+        #  from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning
+        return """In this task, you are given sentences from movie reviews. The task is to classify a sentence as 'positive' if the sentiment of the sentence is positive or as 'negative' if the sentiment of the sentence is negative. Return label only without any other text."""
+
+
+class SST2(SentimentAnalysis):
+    shorthand = "sa"
+
+    def load_validation_set(
+        self, validation_dataset: str | None, validation_split: str | None
+    ):
+        return super().load_validation_set("SetFit/sst2", "validation[:200]")
+
+    def load_test_set(self, test_dataset: str, test_split: str | None):
+        return super().load_test_set("SetFit/sst2", "test")
diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
new file mode 100644
index 0000000..40c3bd0
--- /dev/null
+++ b/evoprompt/task/task.py
@@ -0,0 +1,245 @@
+import logging
+import re
+from abc import ABCMeta, abstractmethod
+from argparse import Namespace
+from functools import lru_cache
+from statistics import mean
+from typing import Union
+
+from datasets import Dataset, load_dataset
+from llama_cpp import LlamaGrammar, deque
+from tqdm import tqdm
+
+from evoprompt.models import LLMModel
+from evoprompt.opt_types import ModelUsage
+from evoprompt.utils import log_calls
+
+logger = logging.getLogger(__name__)
+
+
+SYSTEM_MESSAGE = """
+You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+"""
+
+DatasetDatum = dict
+
+
+class EarlyStoppingMonitor:
+
+    @abstractmethod
+    def update(self, score: float) -> bool:
+        raise NotImplementedError
+
+
+class MomentBasedStopping(EarlyStoppingMonitor):
+    """
+    Watch the first derivative (moment) of the metric to determine when to stop.
+    """
+
+    def __init__(
+        self,
+        *,
+        patience: int = 10,
+        start_after: int = 20,
+        min_moment_magnitude: float = 0.001,
+    ):
+        self.patience = patience
+        self.start_after = start_after
+        self.min_moment_magnitude = min_moment_magnitude
+
+        self.moment_magnitudes = deque(maxlen=patience)
+        self.last_score = 0.0
+        self.num_calls = 0
+
+    def update(self, score: float) -> bool:
+        # caclulate the current moment (dx/dt)
+        self.num_calls += 1
+        if self.num_calls < self.start_after:
+            return False
+
+        self.moment_magnitudes.append(abs(score - self.last_score))
+        self.last_score = score
+        if len(self.moment_magnitudes) < self.patience:
+            return False
+
+        if mean(self.moment_magnitudes) < self.min_moment_magnitude:
+            return True
+
+        return False
+
+
+class ParentBaselineBasedStopping(EarlyStoppingMonitor):
+
+    def __init__(
+        self,
+        parent_histories: list[list[float]],
+        *,
+        patience: int = 10,
+        start_after: int = 20,
+        min_improvement: float = 0.001,
+    ):
+        self.parent_histories = parent_histories
+        self.patience = patience
+        self.start_after = start_after
+        self.min_improvement = min_improvement
+        self.num_calls = 0
+        self.improvement_memory = deque(maxlen=patience)
+
+    def update(self, score: float) -> bool:
+        self.num_calls += 1
+        if self.num_calls < self.start_after:
+            return False
+
+        parent_values = [  # get the metric value of the parents at the current step
+            (
+                parent_history[self.num_calls - 1]
+                if len(parent_history) >= self.num_calls
+                else parent_history[-1]  # extend with last value
+            )
+            for parent_history in self.parent_histories
+        ]
+        self.improvement_memory.append(
+            score - max(parent_values)  # compare with the best parent
+        )
+
+        if len(self.improvement_memory) < self.patience:
+            return False
+
+        if max(self.improvement_memory) < self.min_improvement:
+            # if the highest improvement is less than the minimum improvement, we stop
+            return True
+
+        return False
+
+
+class Task(metaclass=ABCMeta):
+    shorthand: str
+    validation_dataset: Dataset
+    test_dataset: Dataset
+
+    def __init__(
+        self,
+        model: Union[LLMModel],
+        validation_dataset: str | None = None,
+        test_dataset: str | None = None,
+        *,
+        use_grammar: bool,
+        no_early_stopping: bool = False,
+        evaluate_shortest_first: bool = False,
+        validation_split: str | None = None,
+        test_split: str | None = None,
+        debug: bool = False,
+        **kwargs,
+    ) -> None:
+        self.model = model
+        self.debug = debug
+        # whether we use the grammar to constrain the model output or not
+        self.use_grammar = use_grammar
+        self.no_early_stopping = no_early_stopping
+        self.evaluate_shortest_first = evaluate_shortest_first
+        if evaluate_shortest_first and no_early_stopping:
+            logger.warning(
+                "Both 'evaluate_shortest_first' and 'no_early_stopping' are set. This makes no sense"
+            )
+        self.validation_dataset = self.load_validation_set(
+            validation_dataset, validation_split
+        )
+        self.test_dataset = self.load_test_set(test_dataset, test_split)
+
+    def load_validation_set(
+        self, validation_dataset: str, validation_split: str | None
+    ):
+        dataset = load_dataset(validation_dataset, split=validation_split)
+        if self.debug and len(dataset) > 5:
+            dataset = dataset.select(range(5))
+        return dataset
+
+    def load_test_set(self, test_dataset: str, test_split: str | None):
+        dataset = load_dataset(test_dataset, split=test_split)
+        if self.debug and len(dataset) > 20:
+            dataset = dataset.select(range(20))
+        return dataset
+
+    @abstractmethod
+    def predict(self, prompt: str, *args, **kwargs) -> tuple[str, ModelUsage]:
+        pass
+
+    @abstractmethod
+    def _evaluate_sample(
+        self, prompt: str, datum: DatasetDatum
+    ) -> tuple[str, ModelUsage]:
+        pass
+
+    @abstractmethod
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        pass
+
+    @abstractmethod
+    def _aggregate_result(self, results: list) -> float:
+        pass
+
+    def evaluate(
+        self,
+        prompt: str,
+        dataset: Dataset,
+        parent_histories: list[list[float]] | None = None,
+    ) -> tuple[float, ModelUsage, list[float]]:
+
+        early_stopping: EarlyStoppingMonitor
+        early_stopping_params = {
+            "patience": max(len(dataset) // 20, 5),
+            "start_after": max(len(dataset) // 5, 5),
+        }
+        if parent_histories is not None:
+            early_stopping = ParentBaselineBasedStopping(
+                parent_histories, **early_stopping_params
+            )
+        else:
+            early_stopping = MomentBasedStopping(**early_stopping_params)
+
+        results: list = []
+        if self.evaluate_shortest_first:
+            dataset = sorted(dataset, key=lambda x: len(self._get_text_for_datum(x)))
+        dataset_iterator: tqdm[DatasetDatum] = tqdm(
+            dataset, desc="evaluating prompt", leave=False
+        )
+        evaluation_usage = ModelUsage()
+        evaluation_history = []
+
+        for datum in dataset_iterator:
+            result, usage = self._evaluate_sample(prompt, datum)
+            results.append(result)
+            current_metric = self._aggregate_result(results)
+            dataset_iterator.set_postfix(
+                {self.metric_name: f"{current_metric*100:.1f}%"}
+            )
+            evaluation_usage += usage
+            evaluation_history.append(current_metric)
+            if not self.no_early_stopping and early_stopping.update(current_metric):
+                logger.info(
+                    f"Early stopping after {len(results)} samples with {self.metric_name} of {current_metric*100:.1f}%"
+                )
+                break
+        # input(f'F1 score: {current_metric:.2f}')
+
+        return self._aggregate_result(results), evaluation_usage, evaluation_history
+
+    @log_calls("Evaluating validation dataset")
+    def evaluate_validation(
+        self, prompt: str, parent_histories: list[list[float]] | None = None
+    ):
+        return self.evaluate(prompt, self.validation_dataset, parent_histories)
+
+    @log_calls("Evaluating test dataset")
+    def evaluate_test(self, prompt: str):
+        return self.evaluate(prompt, self.test_dataset)
+
+    @property
+    @abstractmethod
+    def metric_name(self) -> str:
+        pass
+
+    @property
+    @abstractmethod
+    def base_prompt(self) -> str:
+        pass
diff --git a/evoprompt/utils.py b/evoprompt/utils.py
index ac8e6fe..3eefa2e 100644
--- a/evoprompt/utils.py
+++ b/evoprompt/utils.py
@@ -140,3 +140,8 @@ class log_calls:
 
             arguments[argument_name] = value
         return arguments
+
+def get_all_subclasses(cls):
+    return set(cls.__subclasses__()).union(
+        [s for c in cls.__subclasses__() for s in get_all_subclasses(c)]
+    )
\ No newline at end of file
diff --git a/main.py b/main.py
index 6382480..6a3d6b5 100644
--- a/main.py
+++ b/main.py
@@ -72,7 +72,7 @@ if __name__ == "__main__":
         case "openai":
             evaluation_model = Llama(options)
 
-    task = get_task(options.task, evaluation_model, options)
+    task = get_task(options.task, evaluation_model, **options.__dict__)
     logger.info(
         f"Running with task {task.__class__.__name__} on dataset {task.validation_dataset.info.dataset_name}"
     )
-- 
GitLab