diff --git a/cli.py b/cli.py
index 89a875c5f5edad08ab53fba75eefc3ea3d771d19..cc167df835beb3070a1d5942b8e127b5c6a5d0df 100644
--- a/cli.py
+++ b/cli.py
@@ -1,37 +1,3 @@
 from argparse import ArgumentParser
 
-from models import MODELS
-
-# from typing import Literal
-
-# import pydantic
-
-# class Arguments(pydantic.BaseModel):
-#     # Required Args
-#     evolution_engine: Literal['llama2', 'openai'] = pydantic.Field("llama2", description="the evolution engine")
-#     evolution_algorithm: Literal['ga', 'de'] = pydantic.Field(description="the evolution algorithm")
-#     # model_path: str = pydantic.Field(description="the model path")
-#     task: Literal['sa', 'qa'] = pydantic.Field(description="the model path")
-#     use_grammar: bool = pydantic.Field(False, description="use grammar")
-#     debug: bool = pydantic.Field(None, description="use grammar")
-#     chat: bool = pydantic.Field(False, description="use chat mode")
-
-
 argument_parser = ArgumentParser()
-
-argument_parser.add_argument(
-    "--evolution-engine", "-e", type=str, choices=MODELS.keys(), default="llama2"
-)
-argument_parser.add_argument(
-    "--evolution-algorithm", "-a", type=str, choices=["ga", "de"], default="ga"
-)
-argument_parser.add_argument(
-    "--task", "-t", type=str, required=True, choices=["sa", "qa"]
-)
-argument_parser.add_argument("--use-grammar", "-g", action="store_true")
-argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
-argument_parser.add_argument("--chat", "-c", action="store_true")
-argument_parser.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
-argument_parser.add_argument(
-    "--llama-path", default="models/llama-2-13b-chat.Q5_K_M.gguf"
-)
diff --git a/evolution.py b/evolution.py
index 200a1b5bdd261c0190898bcb026a540b7dc37a16..4bf51eb424bf6600d0662826f45221c7ff67768c 100644
--- a/evolution.py
+++ b/evolution.py
@@ -1,11 +1,13 @@
 from abc import abstractmethod
 
-from models import LLMModel
 from numpy.random import choice
+from tqdm import trange
+
+from cli import argument_parser
+from models import LLMModel
 from opt_types import ModelUsage, Prompt
 from optimization import PromptOptimization, save_snapshot
 from task import Task
-from tqdm import trange
 from utils import initialize_run_directory, log_calls, logger
 
 SYSTEM_MESSAGE = (
@@ -33,7 +35,8 @@ Basic Prompt: {basic_prompt}
 
 
 class EvolutionAlgorithm(PromptOptimization):
-    # TODO add docstrings
     """The super class for all evolution algorithms containing shared parameters."""
+
+    shorthand: str
 
     def __init__(
@@ -181,6 +184,8 @@ class EvolutionAlgorithm(PromptOptimization):
 class GeneticAlgorithm(EvolutionAlgorithm):
     """The genetic algorithm implemented using LLMs."""
 
+    shorthand = "ga"
+
     @log_calls("Performing prompt evolution using GA")
     def evolve(
         self,
@@ -233,6 +238,8 @@ class GeneticAlgorithm(EvolutionAlgorithm):
 class DifferentialEvolution(EvolutionAlgorithm):
     """The genetic algorithm implemented using LLMs."""
 
+    shorthand = "de"
+
     @log_calls("Performing prompt evolution using GA")
     def evolve(
         self,
@@ -275,3 +282,20 @@
             )
         ]
         return population
+
+
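+# Registry of optimizers discovered from EvolutionAlgorithm subclasses, keyed by their shorthand ("ga", "de").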
+optimizers = {
+    algorithm.shorthand: algorithm for algorithm in EvolutionAlgorithm.__subclasses__()
+}
+
+
+def get_optimizer_class(name: str):
+    if name not in optimizers:
+        raise ValueError("Optimization Algorithm %s does not exist", name)
+    return optimizers[name]
+
+
+argument_parser.add_argument(
+    "--evolution-algorithm", "-a", type=str, choices=optimizers.keys(), default="ga"
+)
diff --git a/main.py b/main.py
index 674d80c7a617c83b1390a96ff3d1d79918622537..eca0a65a5cf1189f429dc51669f082767cf5e0b4 100644
--- a/main.py
+++ b/main.py
@@ -4,9 +4,9 @@ from typing import Any
 from dotenv import load_dotenv
 
 from cli import argument_parser
-from evolution import DifferentialEvolution, GeneticAlgorithm
-from models import Llama2, get_model_init
-from task import QuestionAnswering, SentimentAnalysis
+from evolution import DifferentialEvolution, GeneticAlgorithm, get_optimizer_class
+from models import Llama2, LLMModel, OpenAI, get_model
+from task import QuestionAnswering, SentimentAnalysis, get_task
 from utils import logger
 
 load_dotenv()
@@ -26,8 +26,7 @@ if __name__ == "__main__":
     options = argument_parser.parse_args()
 
     # set up evolution model
-    model_init_fn = get_model_init(options.evolution_engine)
-    evolution_model = model_init_fn(options)
+    evolution_model = get_model(options.evolution_engine, options)
 
     match options.evolution_engine:
         case "llama2":
@@ -39,6 +38,7 @@ if __name__ == "__main__":
     # NOTE currently we always stick to Llama2 as the evaluation engine
     # TODO allow to set separate engine and model for evaluation?
     logger.info("Using Llama2 as the evaluation engine")
+    evaluation_model: LLMModel
     match options.evolution_engine:
         case "llama2":
             evaluation_model = evolution_model
@@ -61,40 +61,14 @@ if __name__ == "__main__":
                 f"'{os.getenv('EP_DEBUG')}' is not allowed for env variable EP_DEBUG."
             )
 
-    match options.task:
-        case "sa":
-            logger.info("Running with task sentiment analysis on dataset SetFit/sst2")
-            task = SentimentAnalysis(
-                evaluation_model,
-                "SetFit/sst2",
-                "SetFit/sst2",
-                use_grammar=options.use_grammar,
-                validation_split=f"validation[:{5 if debug else 200}]",
-                test_split="test[:20]" if debug else "test",
-            )
-        case "qa":
-            logger.info("Running with task question answering on dataset squad")
-            task = QuestionAnswering(
-                evaluation_model,
-                "squad",
-                "squad",
-                use_grammar=options.use_grammar,
-                validation_split=f"train[:{5 if debug else 200}]",
-                test_split="validation[:20]" if debug else "validation",
-            )
-        case _:
-            raise ValueError(
-                f"Task {options.task} does not exist. Choose from 'sa', 'qa'."
-            )
+    task = get_task(options.task, evaluation_model, options)
+    logger.info(
+        f"Running with task {task.__class__.__name__} on dataset {task.validation_dataset.info.dataset_name}"
+    )
 
     logger.info("Using evolutionary algorithm '%s'", options.evolution_algorithm)
 
-    # TODO allow to register algorithms and map to classes
-    if options.evolution_algorithm == "ga":
-        optimizer_class = GeneticAlgorithm
-    else:
-        optimizer_class = DifferentialEvolution
-
+    optimizer_class = get_optimizer_class(options.evolution_algorithm)
     optimizer = optimizer_class(
         population_size=10,
         task=task,
diff --git a/models.py b/models.py
index 56d50164aad09a93ed67e734f9cdd33fc379a264..2c4cad5f369d911605dd848ffaf9d0ceb24d2123 100644
--- a/models.py
+++ b/models.py
@@ -1,11 +1,12 @@
 from abc import abstractmethod
-from argparse import Namespace
+from argparse import ArgumentParser, Namespace
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any
 
 import openai
 from llama_cpp import Llama
 
+from cli import argument_parser
 from opt_types import ModelUsage
 
 current_directory = Path(__file__).resolve().parent
@@ -34,33 +35,12 @@ class LLMModel:
     ) -> Any:
         pass
 
-
-MODELS: dict[str, type[LLMModel]] = {}
-
-
-def register_model(name: str):
-    def wrapper(f: type[LLMModel]):
-        global MODELS
-        if name in MODELS:
-            raise ValueError("Cannot register model class %s: already exists", name)
-        MODELS[name] = f
-
-        return f
-
-    return wrapper
-
-
-def get_models():
-    return MODELS
-
-
-def get_model_init(name: str):
-    if name not in MODELS:
-        raise ValueError("Model %s does not exist", name)
-    return MODELS[name]
+    @classmethod
+    @abstractmethod
+    def register_arguments(cls, parser: ArgumentParser):
+        pass
 
 
-@register_model("llama2")
 class Llama2(LLMModel):
     """Loads and queries a Llama2 model."""
 
@@ -142,8 +122,14 @@ class Llama2(LLMModel):
         self.usage += usage
         return response_text, usage
 
+    @classmethod
+    def register_arguments(cls, parser: ArgumentParser):
+        group = parser.add_argument_group("Llama2 model arguments")
+        group.add_argument(
+            "--llama-path", default="models/llama-2-13b-chat.Q5_K_M.gguf"
+        )
+
 
-@register_model("openai")
 class OpenAI(LLMModel):
     """Queries an OpenAI model using its API."""
 
@@ -215,3 +201,27 @@
             usage = ModelUsage(**response.usage.__dict__)
             self.usage += usage
             return response.choices[0].text, usage
+
+    @classmethod
+    def register_arguments(cls, parser: ArgumentParser):
+        group = parser.add_argument_group("OpenAI model arguments")
+        group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
+
+
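+# Registry of model backends discovered from LLMModel subclasses ("llama2", "openai"); each contributes its own CLI arguments below.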
+models = {model.__name__.lower(): model for model in LLMModel.__subclasses__()}
+for model in models.values():
+    model.register_arguments(argument_parser)
+
+
+def get_model(name: str, options: Namespace):
+    if name not in models:
+        raise ValueError("Model %s does not exist", name)
+    return models[name](options)
+
+
+argument_group = argument_parser.add_argument_group("Model arguments")
+argument_group.add_argument(
+    "--evolution-engine", "-e", type=str, choices=models.keys(), default="llama2"
+)
+argument_group.add_argument("--chat", "-c", action="store_true")
diff --git a/task.py b/task.py
index abbc554dd575297fd192e083691f22b9884ff2b6..adfaa1f12ceb9a69ab4a7214f0889de09a12bc08 100644
--- a/task.py
+++ b/task.py
@@ -1,5 +1,6 @@
 import re
 from abc import abstractmethod
+from argparse import Namespace
 from functools import lru_cache
 from typing import Union
 
@@ -8,7 +9,8 @@ from evaluate import load as load_metric
 from llama_cpp import LlamaGrammar
 from tqdm import tqdm
 
-from models import Llama2, OpenAI
+from cli import argument_parser
+from models import Llama2, LLMModel, OpenAI
 from opt_types import ModelUsage
 from utils import log_calls, logger
 
@@ -20,6 +22,7 @@ DatasetDatum = dict
 
 
 class Task:
+    shorthand: str
     validation_dataset: Dataset
     test_dataset: Dataset
 
@@ -67,9 +70,9 @@ class Task:
         for datum in dataset_iterator:
             result, usage = self._evaluate_sample(prompt, datum)
             results.append(result)
-            current_metrics = self._aggregate_result(results)
+            current_metric = self._aggregate_result(results)
             dataset_iterator.set_postfix(
-                {self.metric_name: f"{current_metrics*100:.1f}%"}
+                {self.metric_name: f"{current_metric*100:.1f}%"}
             )
             evaluation_usage += usage
 
@@ -109,24 +112,16 @@ def sa_grammar_fn(verbose: bool = False):
 
 
 class SentimentAnalysis(Task):
+    shorthand = "sa"
 
-    def __init__(
-        self,
-        model,
-        validation_dataset: str,
-        test_dataset: str,
-        *,
-        use_grammar: bool,
-        validation_split: str | None = None,
-        test_split: str | None = None,
-    ) -> None:
+    def __init__(self, model, options: Namespace):
         super().__init__(
             model,
-            validation_dataset,
-            test_dataset,
-            use_grammar=use_grammar,
-            validation_split=validation_split,
-            test_split=test_split,
+            validation_dataset="SetFit/sst2",
+            test_dataset="SetFit/sst2",
+            use_grammar=options.use_grammar,
+            validation_split=f"validation[:{5 if options.debug else 200}]",
+            test_split="test[:20]" if options.debug else "test",
         )
 
     def predict(self, prompt: str, text: str):
@@ -219,27 +214,19 @@ def qa_grammar_fn(context: str, verbose: bool = False):
 
 
 class QuestionAnswering(Task):
+    shorthand = "qa"
 
-    def __init__(
-        self,
-        model,
-        validation_dataset: str,
-        test_dataset: str,
-        *,
-        use_grammar: bool,
-        validation_split: str | None = None,
-        test_split: str | None = None,
-    ) -> None:
+    def __init__(self, model, options: Namespace):
 
         self.metric = load_metric("squad")
 
         super().__init__(
             model,
-            validation_dataset,
-            test_dataset,
-            use_grammar=use_grammar,
-            validation_split=validation_split,
-            test_split=test_split,
+            "squad",
+            "squad",
+            use_grammar=options.use_grammar,
+            validation_split=f"train[:{5 if options.debug else 200}]",
+            test_split="validation[:20]" if options.debug else "validation",
         )
 
     def predict(self, prompt: str, context: str, question: str):
@@ -330,3 +317,21 @@
     def base_prompt(self):
         # TODO find good prompt
         return """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
+
+
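+# Registry of tasks discovered from Task subclasses, keyed by their shorthand ("sa", "qa").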
+tasks = {task.shorthand: task for task in Task.__subclasses__()}
+
+
+def get_task(name: str, evaluation_model: LLMModel, options: Namespace):
+    if name not in tasks:
+        raise ValueError("Model %s does not exist", name)
+    return tasks[name](evaluation_model, options)
+
+
+argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
+argument_group = argument_parser.add_argument_group("Task arguments")
+argument_group.add_argument("--use-grammar", "-g", action="store_true")
+argument_group.add_argument(
+    "--task", "-t", type=str, required=True, choices=["sa", "qa"]
+)