diff --git a/cli.py b/cli.py
index 071c220d14580eeed8cbf0263a894ddd80790ee8..0e421f74ea03f3b6c52b16ef9ade2d489192a98f 100644
--- a/cli.py
+++ b/cli.py
@@ -1,34 +1,3 @@
 from argparse import ArgumentParser
 
-from models import MODELS
-
-# from typing import Literal
-
-# import pydantic
-
-# class Arguments(pydantic.BaseModel):
-#     # Required Args
-#     evolution_engine: Literal['llama2', 'openai'] = pydantic.Field("llama2", description="the evolution engine")
-#     evolution_algorithm: Literal['ga', 'de'] = pydantic.Field(description="the evolution algorithm")
-#     # model_path: str = pydantic.Field(description="the model path")
-#     task: Literal['sa', 'qa'] = pydantic.Field(description="the model path")
-#     use_grammar: bool = pydantic.Field(False, description="use grammar")
-#     debug: bool = pydantic.Field(None, description="use grammar")
-#     chat: bool = pydantic.Field(False, description="use chat mode")
-
-
-argument_parser = ArgumentParser()
-
-argument_parser.add_argument(
-    "--evolution-engine", "-e", type=str, choices=MODELS.keys(), default="llama2"
-)
-argument_parser.add_argument(
-    "--evolution-algorithm", "-a", type=str, choices=["ga", "de"], default="ga"
-)
-argument_parser.add_argument("--model", "-m", type=str)
-argument_parser.add_argument(
-    "--task", "-t", type=str, required=True, choices=["sa", "qa"]
-)
-argument_parser.add_argument("--use-grammar", "-g", action="store_true")
-argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
-argument_parser.add_argument("--chat", "-c", action="store_true")
+argument_parser = ArgumentParser(conflict_handler="resolve")
diff --git a/download_model.sh b/download_model.sh
index 3ef764bed81f136bc3fbd1258cb49c270cfc4a4f..5f6ab36572aba6bdb803933db2185110d925c94a 100644
--- a/download_model.sh
+++ b/download_model.sh
@@ -1,3 +1,4 @@
-TARGET_PATH='models/llama-2-13b-chat.Q5_K_M.gguf'
-mkdir -p "$(dirname "${TARGET_PATH}")"
-curl -L "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve/main/llama-2-13b-chat.Q5_K_M.gguf?download=true" -o ${TARGET_PATH}
+TARGET_PATH='models'
+mkdir -p "${TARGET_PATH}"
+curl -L "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve/main/llama-2-13b-chat.Q5_K_M.gguf?download=true" -o "${TARGET_PATH}/llama-2-13b-chat.Q5_K_M.gguf"
+curl -L "https://huggingface.co/QuantFactory/Meta-Llama-3-8B-GGUF/resolve/main/Meta-Llama-3-8B.Q8_0.gguf?download=true" -o "${TARGET_PATH}/Meta-Llama-3-8B.Q8_0.gguf"
diff --git a/evolution.py b/evolution.py
index 200a1b5bdd261c0190898bcb026a540b7dc37a16..d2c3e1fcef407ebc3c70e23c6420d9aeb7230062 100644
--- a/evolution.py
+++ b/evolution.py
@@ -1,11 +1,13 @@
 from abc import abstractmethod
 
-from models import LLMModel
 from numpy.random import choice
+from tqdm import trange
+
+from cli import argument_parser
+from models import LLMModel
 from opt_types import ModelUsage, Prompt
 from optimization import PromptOptimization, save_snapshot
 from task import Task
-from tqdm import trange
 from utils import initialize_run_directory, log_calls, logger
 
 SYSTEM_MESSAGE = (
@@ -33,7 +35,8 @@ Basic Prompt: {basic_prompt}
 
 
 class EvolutionAlgorithm(PromptOptimization):
-    # TODO add docstrings
     """The super class for all evolution algorithms containing shared parameters."""
+
+    shorthand: str  # short name used to select the algorithm on the command line
 
     def __init__(
@@ -166,7 +169,7 @@ class EvolutionAlgorithm(PromptOptimization):
         logger.info(f"Best prompt: {p}")
 
         # We pick the prompt with the highest score on the development set and report its score on the testset.
-        test_performance, _ = self.task.evaluate_test(p.content)
+        test_performance, _, _ = self.task.evaluate_test(p.content)
         logger.info("Best prompt on test set: %s", test_performance)
         logger.info(
             "Usage (evolution model / evaluation model / total): %s / %s / %s",
@@ -181,6 +184,8 @@ class EvolutionAlgorithm(PromptOptimization):
 class GeneticAlgorithm(EvolutionAlgorithm):
     """The genetic algorithm implemented using LLMs."""
 
+    shorthand = "ga"
+
     @log_calls("Performing prompt evolution using GA")
     def evolve(
         self,
@@ -233,6 +238,8 @@ class GeneticAlgorithm(EvolutionAlgorithm):
 class DifferentialEvolution(EvolutionAlgorithm):
     """The genetic algorithm implemented using LLMs."""
 
+    shorthand = "de"
+
     @log_calls("Performing prompt evolution using GA")
     def evolve(
         self,
@@ -275,3 +282,19 @@ class DifferentialEvolution(EvolutionAlgorithm):
             )
         ]
         return population
+
+
+optimizers = {
+    algorithm.shorthand: algorithm for algorithm in EvolutionAlgorithm.__subclasses__()
+}
+
+
+def get_optimizer_class(name: str):
+    if name not in optimizers:
+        raise ValueError(f"Optimization algorithm '{name}' does not exist")
+    return optimizers[name]
+
+
+argument_parser.add_argument(
+    "--evolution-algorithm", "-a", type=str, choices=optimizers.keys(), default="ga"
+)
diff --git a/main.py b/main.py
index 4a139592384d88bce4fb0da3786aff2718a9143d..4ca5c5128953d3e98302bae41785f8e4c3774d46 100644
--- a/main.py
+++ b/main.py
@@ -1,11 +1,12 @@
 import os
 from typing import Any
 
-from cli import argument_parser
 from dotenv import load_dotenv
-from evolution import DifferentialEvolution, GeneticAlgorithm
-from models import Llama2, get_model_init
-from task import QuestionAnswering, SentimentAnalysis
+
+from cli import argument_parser
+from evolution import get_optimizer_class
+from models import Llama3, LLMModel
+from task import get_task
 from utils import logger
 
 load_dotenv()
@@ -25,31 +26,24 @@ if __name__ == "__main__":
     options = argument_parser.parse_args()
 
     # set up evolution model
-    model_init_fn = get_model_init(options.evolution_engine)
-    evolution_model = model_init_fn(
-        options.model,
-        chat=options.chat,
-    )
+    evolution_model = LLMModel.get_model(options.evolution_engine, options)
 
     match options.evolution_engine:
         case "llama2":
-            logger.info("Using Llama2 client as the evolution engine")
+            logger.info("Using Llama2 as the evolution engine")
         case "openai":
-            logger.info("Using OpenAI client as the evolution engine")
-
+            logger.info(f"Using {options.openai_model} as the evolution engine")
 
     # set up evaluation model
-    # NOTE currenty we always stick to Llama2 as evaluation engine
+    # NOTE currently we always stick to a Llama model as the evaluation engine
     # TODO allow to set separate engine and model for evaluation?
-    logger.info("Using Llama2 client as the evaluation engine")
+    logger.info("Using a Llama model as the evaluation engine")
+    evaluation_model: LLMModel
     match options.evolution_engine:
-        case "llama2":
+        case "llama2" | "llama3":
             evaluation_model = evolution_model
         case "openai":
-            evaluation_model = Llama2(
-                model_path=options.model,
-                chat=options.chat,
-            )
+            evaluation_model = Llama3(options)
 
     # log cli arguments
     logger.info(
@@ -67,40 +61,14 @@ if __name__ == "__main__":
                 f"'{os.getenv('EP_DEBUG')}' is not allowed for env variable EP_DEBUG."
             )
 
-    match options.task:
-        case "sa":
-            logger.info("Running with task sentiment analysis on dataset SetFit/sst2")
-            task = SentimentAnalysis(
-                evaluation_model,
-                "SetFit/sst2",
-                "SetFit/sst2",
-                use_grammar=options.use_grammar,
-                validation_split=f"validation[:{5 if debug else 200}]",
-                test_split="test[:20]" if debug else "test",
-            )
-        case "qa":
-            logger.info("Running with task question answering on dataset squad")
-            task = QuestionAnswering(
-                evaluation_model,
-                "squad",
-                "squad",
-                use_grammar=options.use_grammar,
-                validation_split=f"train[:{5 if debug else 200}]",
-                test_split="validation[:20]" if debug else "validation",
-            )
-        case _:
-            raise ValueError(
-                f"Task {options.task} does not exist. Choose from 'sa', 'qa'."
-            )
+    task = get_task(options.task, evaluation_model, options)
+    logger.info(
+        f"Running with task {task.__class__.__name__} on dataset {task.validation_dataset.info.dataset_name}"
+    )
 
     logger.info("Using evolutionary algorithm '%s'", options.evolution_algorithm)
 
-    # TODO allow to register algorithms and map to classes
-    if options.evolution_algorithm == "ga":
-        optimizer_class = GeneticAlgorithm
-    else:
-        optimizer_class = DifferentialEvolution
-
+    optimizer_class = get_optimizer_class(options.evolution_algorithm)
     optimizer = optimizer_class(
         population_size=10,
         task=task,
diff --git a/models.py b/models.py
index f6f9fd6afa94004914f98e4db8f003df25f044a6..bf0d49cad99956612e3ca2a060b8b3bc562d2a54 100644
--- a/models.py
+++ b/models.py
@@ -1,47 +1,38 @@
-from abc import abstractmethod
+import abc
+import inspect
+from abc import ABC, abstractmethod, abstractproperty
+from argparse import ArgumentParser, Namespace
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any, ClassVar
 
 import openai
 from llama_cpp import Llama
+
+from cli import argument_parser
 from opt_types import ModelUsage
 
 current_directory = Path(__file__).resolve().parent
 
 
-MODELS = {}
-
-
-def register_model(name: str):
-    def wrapper(f: Callable):
-        global MODELS
-        if name in MODELS:
-            raise ValueError("Cannot register model class %s: already exists", name)
-        MODELS[name] = f
-
-        return f
-
-    return wrapper
-
-
-def get_models():
-    return MODELS
-
-
-def get_model_init(name: str):
-    if name not in MODELS:
-        raise ValueError("Model %s does not exist", name)
-    return MODELS[name]
+class LLMModel(ABC):
+    models: ClassVar[dict[str, type["LLMModel"]]] = {}
+    chat: bool
 
+    def __init_subclass__(cls) -> None:
+        if inspect.isabstract(cls):
+            return
+        cls.models[cls.__name__.lower()] = cls
+        cls.register_arguments(argument_parser)
 
-class LLMModel:
-    chat: bool
-    model: Any
+    @classmethod
+    def get_model(cls, name: str, options: Namespace):
+        if name not in cls.models:
+            raise ValueError(f"Model '{name}' does not exist")
+        return cls.models[name](options)
 
-    def __init__(self, chat: bool, model: Any):
+    def __init__(self, options: Namespace):
         self.usage = ModelUsage()
-        self.chat = chat
-        self.model = model
+        self.chat = options.chat
 
     @abstractmethod
     def __call__(
@@ -59,15 +50,17 @@ class LLMModel:
     ) -> Any:
         pass
 
+    @classmethod
+    @abstractmethod
+    def register_arguments(cls, parser: ArgumentParser):
+        pass
+
 
-@register_model("llama2")
-class Llama2(LLMModel):
-    """Loads and queries a Llama2 model."""
+class LlamaModel(LLMModel):
 
     def __init__(
         self,
-        model_path: str,
-        chat: bool = False,
+        options: Namespace,
         n_gpu_layers: int = 60,
         n_threads: int = 8,
         n_ctx: int = 4096,
@@ -76,16 +69,17 @@ class Llama2(LLMModel):
     ) -> None:
 
         # initialize model
-        model = Llama(
-            model_path,
-            chat_format="llama-2",
+        self.model = Llama(
+            options.llama_path,
+            chat_format=self.chat_format,
             verbose=verbose,
             n_gpu_layers=n_gpu_layers,
             n_threads=n_threads,
             n_ctx=n_ctx,
             **kwargs,
         )
-        super().__init__(chat, model)
+
+        super().__init__(options)
 
     def __call__(
         self,
@@ -142,19 +136,42 @@ class Llama2(LLMModel):
         self.usage += usage
         return response_text, usage
 
+    @property
+    @abstractmethod
+    def chat_format(self) -> str:
+        pass
+
+    @classmethod
+    def register_arguments(cls, parser: ArgumentParser):
+        group = parser.add_argument_group(f"{cls.__name__} model arguments")
+        group.add_argument(
+            "--llama-path", default="models/llama-2-13b-chat.Q5_K_M.gguf"
+        )
+
+
+class Llama2(LlamaModel):
+    @property
+    def chat_format(self) -> str:
+        return "llama-2"
+
+
+class Llama3(LlamaModel):
+    @property
+    def chat_format(self) -> str:
+        return "llama-3"
+
 
-@register_model("openai")
 class OpenAI(LLMModel):
     """Queries an OpenAI model using its API."""
 
     def __init__(
         self,
-        model: str = "gpt-3.5-turbo",
-        chat: bool = False,
+        options: Namespace,
         verbose: bool = False,
         **kwargs,
     ) -> None:
-        super().__init__(chat, model)
+        self.model_name = options.openai_model
+        super().__init__(options)
 
         # initialize client for API calls
         self.openai_client = openai.OpenAI(**kwargs)
@@ -191,7 +208,7 @@ class OpenAI(LLMModel):
                     },
                 )
             response = self.openai_client.chat.completions.create(
-                model=self.model,
+                model=self.model_name,
                 messages=messages,
                 stop=stop,
                 max_tokens=max_tokens,
@@ -215,3 +232,19 @@ class OpenAI(LLMModel):
             usage = ModelUsage(**response.usage.__dict__)
             self.usage += usage
             return response.choices[0].text, usage
+
+    @classmethod
+    def register_arguments(cls, parser: ArgumentParser):
+        group = parser.add_argument_group("OpenAI model arguments")
+        group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
+
+
+argument_group = argument_parser.add_argument_group("Model arguments")
+argument_group.add_argument(
+    "--evolution-engine",
+    "-e",
+    type=str,
+    choices=LLMModel.models.keys(),
+    default="llama2",
+)
+argument_group.add_argument("--chat", "-c", action="store_true")
diff --git a/opt_types.py b/opt_types.py
index 3d46de80719d2ff86025b92880b51e04044b113f..a80c850ff7032e393cfbfc73c072769f857a670d 100644
--- a/opt_types.py
+++ b/opt_types.py
@@ -30,6 +30,7 @@ class Prompt:
     content: str
     score: float
     usage: ModelUsage
+    evaluation_history: list[float]
     meta: dict = field(default_factory=dict)
     id: str = field(default_factory=lambda: uuid4().hex)
 
diff --git a/optimization.py b/optimization.py
index 417c12cbebac486ab0d94bf5c8027c3ac893bb72..f5592fcf1bdac49076af5e80730c02a309e35378 100644
--- a/optimization.py
+++ b/optimization.py
@@ -57,22 +57,36 @@ class PromptOptimization:
     def reset(self):
         self._init
 
-    def evaluate_prompt(self, prompt: str):
-        return self.task.evaluate_validation(prompt)
+    def evaluate_prompt(self, prompt: str, parents: tuple[Prompt, ...] | None = None):
+        parent_histories = (
+            [parent.evaluation_history for parent in parents]
+            if parents is not None
+            else None
+        )
+        return self.task.evaluate_validation(prompt, parent_histories)
 
     def add_prompt(
-        self, prompt: str, parents: tuple[Prompt] = None, meta: dict = None
+        self,
+        prompt: str,
+        parents: tuple[Prompt, ...] | None = None,
+        meta: dict | None = None,
     ) -> Prompt:
-        score, usage = self.evaluate_prompt(prompt)
-        prompt = Prompt(content=prompt, score=score, meta=meta, usage=usage)
+        score, usage, history = self.evaluate_prompt(prompt, parents)
+        prompt_object = Prompt(
+            content=prompt,
+            score=score,
+            meta=meta if meta is not None else {},
+            usage=usage,
+            evaluation_history=history,
+        )
 
         # keep track of prompt
-        self.all_prompts[prompt.id] = prompt
-        self.family_tree[prompt.id] = (
+        self.all_prompts[prompt_object.id] = prompt_object
+        self.family_tree[prompt_object.id] = (
             tuple(p.id for p in parents) if parents is not None else None
         )
 
-        return prompt
+        return prompt_object
 
     def add_prompts(
         self,
@@ -86,7 +100,7 @@ class PromptOptimization:
         ]
 
     def get_prompt(self, prompt_id: str):
-        return self.all_prompt[prompt_id]
+        return self.all_prompts[prompt_id]
 
     def get_prompts(self, prompt_ids: list[str]):
         return [self.get_prompt(p_id) for p_id in prompt_ids]
diff --git a/task.py b/task.py
index 461348d35be3b41dbb79986ca5df87e47640056b..d062253eeee5e38db3388d5879da7064d3afee4b 100644
--- a/task.py
+++ b/task.py
@@ -1,23 +1,121 @@
 import re
 from abc import abstractmethod
-from collections import defaultdict
+from argparse import Namespace
 from functools import lru_cache
-from typing import DefaultDict, Mapping, Union
+from statistics import mean
+from typing import Union
 
 from datasets import Dataset, load_dataset
 from evaluate import load as load_metric
-from llama_cpp import LlamaGrammar
-from models import Llama2, OpenAI
-from opt_types import ModelUsage
+from collections import deque
+from llama_cpp import LlamaGrammar
 from tqdm import tqdm
+
+from cli import argument_parser
+from models import Llama2, LLMModel, OpenAI
+from opt_types import ModelUsage
 from utils import log_calls, logger
 
 SYSTEM_MESSAGE = """
 You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 """
 
+DatasetDatum = dict
+
+
+class EarlyStoppingMonitor:
+
+    @abstractmethod
+    def update(self, score: float) -> bool:
+        raise NotImplementedError
+
+
+class MomentBasedStopping(EarlyStoppingMonitor):
+    """
+    Watch the first derivative (moment) of the metric to determine when to stop.
+    """
+
+    def __init__(
+        self,
+        *,
+        patience: int = 10,
+        start_after: int = 20,
+        min_moment_magnitude: float = 0.001,
+    ):
+        self.patience = patience
+        self.start_after = start_after
+        self.min_moment_magnitude = min_moment_magnitude
+
+        self.moment_magnitudes = deque(maxlen=patience)
+        self.last_score = 0.0
+        self.num_calls = 0
+
+    def update(self, score: float) -> bool:
+        # calculate the current moment (dx/dt)
+        self.num_calls += 1
+        if self.num_calls < self.start_after:
+            return False
+
+        self.moment_magnitudes.append(abs(score - self.last_score))
+        self.last_score = score
+        if len(self.moment_magnitudes) < self.patience:
+            return False
+
+        if mean(self.moment_magnitudes) < self.min_moment_magnitude:
+            return True
+
+        return False
+
+
+class ParentBaselineBasedStopping(EarlyStoppingMonitor):
+
+    def __init__(
+        self,
+        parent_histories: list[list[float]],
+        *,
+        patience: int = 10,
+        start_after: int = 20,
+        min_improvement: float = 0.001,
+    ):
+        self.parent_histories = parent_histories
+        self.patience = patience
+        self.start_after = start_after
+        self.min_improvement = min_improvement
+        self.num_calls = 0
+        self.improvement_memory = deque(maxlen=patience)
+
+    def update(self, score: float) -> bool:
+        self.num_calls += 1
+        if self.num_calls < self.start_after:
+            return False
+
+        parent_values = [  # get the metric value of the parents at the current step
+            (
+                parent_history[self.num_calls - 1]
+                if len(parent_history) >= self.num_calls
+                else parent_history[-1]  # extend with last value
+            )
+            for parent_history in self.parent_histories
+        ]
+        self.improvement_memory.append(
+            score - max(parent_values)  # compare with the best parent
+        )
+
+        if len(self.improvement_memory) < self.patience:
+            return False
+
+        if max(self.improvement_memory) < self.min_improvement:
+            # if the highest improvement is less than the minimum improvement, we stop
+            return True
+
+        return False
+
 
 class Task:
+    shorthand: str  # short name used to select the task on the command line
+    validation_dataset: Dataset
+    test_dataset: Dataset
+
     def __init__(
         self,
         model: Union[Llama2, OpenAI],
@@ -43,26 +141,76 @@ class Task:
         pass
 
     @abstractmethod
-    def _evaluate(self, prompt: str, dataset) -> tuple[float, ModelUsage]:
+    def _evaluate_sample(
+        self, prompt: str, datum: DatasetDatum
+    ) -> tuple[str, ModelUsage]:
         pass
 
+    @abstractmethod
+    def _aggregate_result(self, results: list) -> float:
+        pass
+
+    def evaluate(
+        self,
+        prompt: str,
+        dataset: Dataset,
+        parent_histories: list[list[float]] | None = None,
+    ) -> tuple[float, ModelUsage, list[float]]:
+
+        early_stopping: EarlyStoppingMonitor
+        early_stopping_params = {
+            "patience": max(len(dataset) // 20, 5),
+            "start_after": max(len(dataset) // 5, 5),
+        }
+        if parent_histories is not None:
+            early_stopping = ParentBaselineBasedStopping(
+                parent_histories, **early_stopping_params
+            )
+        else:
+            early_stopping = MomentBasedStopping(**early_stopping_params)
+
+        results: list = []
+        dataset_iterator: tqdm[DatasetDatum] = tqdm(
+            dataset, desc="evaluating prompt", leave=False
+        )
+        evaluation_usage = ModelUsage()
+        evaluation_history = []
+
+        for datum in dataset_iterator:
+            result, usage = self._evaluate_sample(prompt, datum)
+            results.append(result)
+            current_metric = self._aggregate_result(results)
+            dataset_iterator.set_postfix(
+                {self.metric_name: f"{current_metric*100:.1f}%"}
+            )
+            evaluation_usage += usage
+            evaluation_history.append(current_metric)
+            if early_stopping.update(current_metric):
+                logger.info(
+                    f"Early stopping after {len(results)} samples with {self.metric_name} of {current_metric*100:.1f}%"
+                )
+                break
+
+        return self._aggregate_result(results), evaluation_usage, evaluation_history
+
     @log_calls("Evaluating validation dataset")
-    @lru_cache(maxsize=None)
-    def evaluate_validation(self, prompt: str):
-        return self._evaluate(prompt, self.validation_dataset)
+    def evaluate_validation(
+        self, prompt: str, parent_histories: list[list[float]] | None = None
+    ):
+        return self.evaluate(prompt, self.validation_dataset, parent_histories)
 
     @log_calls("Evaluating test dataset")
     def evaluate_test(self, prompt: str):
-        return self._evaluate(prompt, self.test_dataset)
+        return self.evaluate(prompt, self.test_dataset)
 
     @property
     @abstractmethod
-    def metric_name(self):
+    def metric_name(self) -> str:
         pass
 
     @property
     @abstractmethod
-    def base_prompt(self):
+    def base_prompt(self) -> str:
         pass
 
 
@@ -80,24 +228,16 @@ def sa_grammar_fn(verbose: bool = False):
 
 
 class SentimentAnalysis(Task):
+    shorthand = "sa"
 
-    def __init__(
-        self,
-        model,
-        validation_dataset: str,
-        test_dataset: str,
-        *,
-        use_grammar: bool,
-        validation_split: str | None = None,
-        test_split: str | None = None,
-    ) -> None:
+    def __init__(self, model, options: Namespace):
         super().__init__(
             model,
-            validation_dataset,
-            test_dataset,
-            use_grammar=use_grammar,
-            validation_split=validation_split,
-            test_split=test_split,
+            validation_dataset="SetFit/sst2",
+            test_dataset="SetFit/sst2",
+            use_grammar=options.use_grammar,
+            validation_split=f"validation[:{5 if options.debug else 200}]",
+            test_split="test[:20]" if options.debug else "test",
         )
 
     def predict(self, prompt: str, text: str):
@@ -117,39 +257,33 @@ class SentimentAnalysis(Task):
 
         return response, usage
 
-    def _evaluate(self, prompt: str, dataset: Dataset):
+    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
         sst2_labels = {"negative": 0, "positive": 1}
 
-        results: DefaultDict[str, int] = defaultdict(int)
-        dataset_iterator = tqdm(dataset, desc="evaluating prompt", leave=False)
-        evaluation_usage = ModelUsage()
-
-        for datum in dataset_iterator:
-            response, usage = self.predict(prompt=prompt, text=datum["text"])
-            response = response.lower()
-            evaluation_usage += usage
-            if self.use_grammar:
-                # model output is from label space
-                answer_label = sst2_labels[response]
+        response, usage = self.predict(prompt=prompt, text=datum["text"])
+        response = response.lower()
+        if self.use_grammar:
+            # model output is from label space
+            answer_label = sst2_labels[response]
+        else:
+            answer_label = None
+            for label in sst2_labels.keys():
+                if label in response:
+                    answer_label = sst2_labels[label]
+                    break
             else:
-                answer_label = None
-                for label in sst2_labels.keys():
-                    if label in response:
-                        answer_label = sst2_labels[label]
-                        break
-                else:
-                    logger.warning(f"Invalid answer: {response}")
-                    results["failed"] += 1
-                    continue
-
-            classification_result = (
-                "incorrect" if answer_label != datum["label"] else "correct"
-            )
-            results[classification_result] += 1
-            dataset_iterator.set_postfix(results)
+                logger.warning(f"Invalid answer: {response}")
+                return "failed", usage
+
+        classification_result = (
+            "incorrect" if answer_label != datum["label"] else "correct"
+        )
+        return classification_result, usage
 
-        accuracy = results["correct"] / sum(results.values())
-        return accuracy, evaluation_usage
+    def _aggregate_result(self, results: list[str]) -> float:
+        num_correct_results = sum(1 for result in results if result == "correct")
+        accuracy = num_correct_results / len(results)
+        return accuracy
 
     @property
     def metric_name(self):
@@ -196,27 +330,19 @@ def qa_grammar_fn(context: str, verbose: bool = False):
 
 
 class QuestionAnswering(Task):
+    shorthand = "qa"
 
-    def __init__(
-        self,
-        model,
-        validation_dataset: str,
-        test_dataset: str,
-        *,
-        use_grammar: bool,
-        validation_split: str | None = None,
-        test_split: str | None = None,
-    ) -> None:
+    def __init__(self, model, options: Namespace):
 
         self.metric = load_metric("squad")
 
         super().__init__(
             model,
-            validation_dataset,
-            test_dataset,
-            use_grammar=use_grammar,
-            validation_split=validation_split,
-            test_split=test_split,
+            "squad",
+            "squad",
+            use_grammar=options.use_grammar,
+            validation_split=f"train[:{5 if options.debug else 200}]",
+            test_split="validation[:20]" if options.debug else "validation",
         )
 
     def predict(self, prompt: str, context: str, question: str):
@@ -255,10 +381,26 @@ class QuestionAnswering(Task):
 
         return response, usage
 
-    def _evaluate(self, prompt: str, dataset: Dataset):
-        evaluation_usage = ModelUsage()
+    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        answer, usage = self.predict(
+            prompt,
+            context=datum["context"],
+            question=datum["question"],
+        )
+        # TODO check if answer is lower-cased in metric computation
+
+        result = self.metric.compute(
+            predictions=[{"prediction_text": answer, "id": datum["id"]}],
+            references=[{"answers": datum["answers"], "id": datum["id"]}],
+        )
+
+        return result["f1"] / 100, usage
+
+    def _aggregate_result(self, results: list[float]) -> float:
+        return sum(results) / len(results)
 
-        def replace_symbol_for_grammar(sample: Mapping):
+    def evaluate(self, prompt: str, dataset: Dataset, parent_histories: list[list[float]] | None = None):
+        def replace_symbol_for_grammar(sample: DatasetDatum):
             symbol_replacement_mapping = {
                 "\u2013": "-",
                 "\u2014": "-",
@@ -281,35 +423,7 @@ class QuestionAnswering(Task):
         if self.use_grammar:
             # NOTE: the LlamaGrammar has issues with symbol '–' therefore we replace all occurences with '-' (hyphen)
             dataset = dataset.map(replace_symbol_for_grammar, desc="Replacing symbols")
-
-        dataset_iterator = tqdm(dataset, desc="evaluating prompt", leave=False)
-
-        num_samples = 0
-        f1 = 0.0
-        em = 0
-        for datum in dataset_iterator:
-            answer, usage = self.predict(
-                prompt,
-                context=datum["context"],
-                question=datum["question"],
-            )
-            # TODO check if answer is lower-cased in metric computation
-
-            evaluation_usage += usage
-
-            num_samples += 1
-            result = self.metric.compute(
-                predictions=[{"prediction_text": answer, "id": datum["id"]}],
-                references=[{"answers": datum["answers"], "id": datum["id"]}],
-            )
-            f1 += result["f1"]
-            em += result["exact_match"]
-
-            dataset_iterator.set_postfix(
-                {"f1": f1 / num_samples, "em": em / num_samples}
-            )
-
-        return f1 / num_samples, evaluation_usage
+        return super().evaluate(prompt, dataset, parent_histories)
 
     @property
     def metric_name(self):
@@ -319,3 +433,20 @@ class QuestionAnswering(Task):
     def base_prompt(self):
         # TODO find good prompt
         return """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
+
+
+tasks = {task.shorthand: task for task in Task.__subclasses__()}
+
+
+def get_task(name: str, evaluation_model: LLMModel, options: Namespace):
+    if name not in tasks:
+        raise ValueError(f"Task '{name}' does not exist")
+    return tasks[name](evaluation_model, options)
+
+
+argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
+argument_group = argument_parser.add_argument_group("Task arguments")
+argument_group.add_argument("--use-grammar", "-g", action="store_true")
+argument_group.add_argument(
+    "--task", "-t", type=str, required=True, choices=["sa", "qa"]
+)