diff --git a/cli.py b/cli.py
index ab05d658a20c4392677707429d9265eed26cb2e0..071c220d14580eeed8cbf0263a894ddd80790ee8 100644
--- a/cli.py
+++ b/cli.py
@@ -1,14 +1,31 @@
 from argparse import ArgumentParser
 
+from models import MODELS
+
+# from typing import Literal
+
+# import pydantic
+
+# class Arguments(pydantic.BaseModel):
+#     # Required Args
+#     evolution_engine: Literal['llama2', 'openai'] = pydantic.Field("llama2", description="the evolution engine")
+#     evolution_algorithm: Literal['ga', 'de'] = pydantic.Field(description="the evolution algorithm")
+#     # model_path: str = pydantic.Field(description="the model path")
+#     task: Literal['sa', 'qa'] = pydantic.Field(description="the task")
+#     use_grammar: bool = pydantic.Field(False, description="use grammar")
+#     debug: bool = pydantic.Field(False, description="debug mode")
+#     chat: bool = pydantic.Field(False, description="use chat mode")
+
+
 argument_parser = ArgumentParser()
 
 argument_parser.add_argument(
-    "--evolution-engine", "-e", type=str, choices=["openai", "llama2"], default="llama2"
+    "--evolution-engine", "-e", type=str, choices=MODELS.keys(), default="llama2"
 )
 argument_parser.add_argument(
     "--evolution-algorithm", "-a", type=str, choices=["ga", "de"], default="ga"
 )
-argument_parser.add_argument("--model-path", "-m", type=str, required=True)
+argument_parser.add_argument("--model", "-m", type=str)
 argument_parser.add_argument(
     "--task", "-t", type=str, required=True, choices=["sa", "qa"]
 )
diff --git a/evolution.py b/evolution.py
index 9a935fac9ede1c7056f65cd42836835f607100ef..200a1b5bdd261c0190898bcb026a540b7dc37a16 100644
--- a/evolution.py
+++ b/evolution.py
@@ -3,10 +3,10 @@ from abc import abstractmethod
 from models import LLMModel
 from numpy.random import choice
 from opt_types import ModelUsage, Prompt
-from optimization import PromptOptimization
+from optimization import PromptOptimization, save_snapshot
 from task import Task
 from tqdm import trange
-from utils import initialize_run_directory, log_calls, logger, save_snapshot
+from utils import initialize_run_directory, log_calls, logger
 
 SYSTEM_MESSAGE = (
     "Please follow the instruction step-by-step to generate a better prompt."
@@ -100,7 +100,9 @@ class EvolutionAlgorithm(PromptOptimization):
 
         run_directory = initialize_run_directory(self.evolution_model)
 
-        initial_prompts, evolution_usage, evaluation_usage = self.init_run(self.population_size)
+        initial_prompts, evolution_usage, evaluation_usage = self.init_run(
+            self.population_size
+        )
         total_evaluation_usage += evaluation_usage
         total_evolution_usage += evolution_usage
 
@@ -166,7 +168,12 @@ class EvolutionAlgorithm(PromptOptimization):
         # We pick the prompt with the highest score on the development set and report its score on the testset.
         test_performance, _ = self.task.evaluate_test(p.content)
         logger.info("Best prompt on test set: %s", test_performance)
-        logger.info("Usage (evolution model / evaluation model / total): %s / %s / %s", total_evolution_usage, total_evaluation_usage, total_evolution_usage + total_evaluation_usage)
+        logger.info(
+            "Usage (evolution model / evaluation model / total): %s / %s / %s",
+            total_evolution_usage,
+            total_evaluation_usage,
+            total_evolution_usage + total_evaluation_usage,
+        )
 
         return total_evolution_usage, total_evaluation_usage
 
diff --git a/main.py b/main.py
index 2095276a34e462f8f4489e69a9d0983247964081..4a139592384d88bce4fb0da3786aff2718a9143d 100644
--- a/main.py
+++ b/main.py
@@ -4,7 +4,7 @@ from typing import Any
 from cli import argument_parser
 from dotenv import load_dotenv
 from evolution import DifferentialEvolution, GeneticAlgorithm
-from models import Llama2, OpenAI
+from models import Llama2, get_model_init
 from task import QuestionAnswering, SentimentAnalysis
 from utils import logger
 
@@ -25,26 +25,29 @@ if __name__ == "__main__":
     options = argument_parser.parse_args()
 
     # set up evolution model
+    model_init_fn = get_model_init(options.evolution_engine)
+    evolution_model = model_init_fn(
+        options.model,
+        chat=options.chat,
+    )
+
     match options.evolution_engine:
         case "llama2":
             logger.info("Using Llama2 client as the evolution engine")
-            evolution_model = Llama2(
-                model_path=options.model_path,
-                chat=options.chat,
-            )
-
         case "openai":
             logger.info("Using OpenAI client as the evolution engine")
-            evolution_model = OpenAI("gpt-3.5-turbo", chat=options.chat)
+
 
     # set up evaluation model
-    # NOTE currenty we always stick to Llama2 as evaluation model
+    # NOTE currently we always stick to Llama2 as the evaluation engine
+    # TODO allow to set separate engine and model for evaluation?
+    logger.info("Using Llama2 client as the evaluation engine")
     match options.evolution_engine:
         case "llama2":
             evaluation_model = evolution_model
         case "openai":
             evaluation_model = Llama2(
-                model_path=options.model_path,
+                model_path=options.model,
                 chat=options.chat,
             )
 
diff --git a/models.py b/models.py
index 486801546244e0c28ce8de309a91a8f6ff866125..f6f9fd6afa94004914f98e4db8f003df25f044a6 100644
--- a/models.py
+++ b/models.py
@@ -1,6 +1,6 @@
 from abc import abstractmethod
 from pathlib import Path
-from typing import Any
+from typing import Any, Callable
 
 import openai
 from llama_cpp import Llama
@@ -9,6 +9,31 @@ from opt_types import ModelUsage
 current_directory = Path(__file__).resolve().parent
 
 
+MODELS = {}
+
+
+def register_model(name: str):
+    def wrapper(f: Callable):
+        global MODELS
+        if name in MODELS:
+            raise ValueError("Cannot register model class %s: already exists", name)
+        MODELS[name] = f
+
+        return f
+
+    return wrapper
+
+
+def get_models():
+    return MODELS
+
+
+def get_model_init(name: str):
+    if name not in MODELS:
+        raise ValueError("Model %s does not exist", name)
+    return MODELS[name]
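+
+
+# Usage sketch (illustrative only; the exact constructor arguments depend on
+# the registered class, e.g. Llama2 expects a model path):
+#   init_fn = get_model_init("llama2")
+#   model = init_fn("path/to/model", chat=True)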
+
+
 class LLMModel:
     chat: bool
     model: Any
@@ -35,6 +60,7 @@ class LLMModel:
         pass
 
 
+@register_model("llama2")
 class Llama2(LLMModel):
     """Loads and queries a Llama2 model."""
 
@@ -117,11 +143,16 @@ class Llama2(LLMModel):
         return response_text, usage
 
 
+@register_model("openai")
 class OpenAI(LLMModel):
     """Queries an OpenAI model using its API."""
 
     def __init__(
-        self, model: str, chat: bool = False, verbose: bool = False, **kwargs
+        self,
+        model: str = "gpt-3.5-turbo",
+        chat: bool = False,
+        verbose: bool = False,
+        **kwargs,
     ) -> None:
         super().__init__(chat, model)
 
diff --git a/optimization.py b/optimization.py
index d379907ecca62cdb5ccc01b1af48b5cf2d90f07a..9f18365c07f93fedd8753f5a082bb6dff17e11d0 100644
--- a/optimization.py
+++ b/optimization.py
@@ -1,7 +1,10 @@
+import json
 from itertools import zip_longest
+from pathlib import Path
+from typing import Any
 
-from models import LLMModel
-from opt_types import ModelUsage, Prompt
+from models import Llama2, LLMModel, OpenAI
+from opt_types import ModelUsage, OptTypeEncoder, Prompt
 from task import Task
 from utils import log_calls
 
@@ -89,7 +92,9 @@ class PromptOptimization:
     def get_prompts(self, prompt_ids: list[str]):
         return [self.get_prompt(p_id) for p_id in prompt_ids]
 
-    def init_run(self, num_initial_prompts: int) -> tuple[list[Prompt], ModelUsage, ModelUsage]:
+    def init_run(
+        self, num_initial_prompts: int
+    ) -> tuple[list[Prompt], ModelUsage, ModelUsage]:
         # - Initial prompts P0 = {p1, p2, . . . , pN }
         paraphrases, paraphrase_usage = paraphrase_prompts(
             self.evolution_model, self.task.base_prompt, n=num_initial_prompts - 1
@@ -107,3 +112,58 @@ class PromptOptimization:
             evaluation_usage += prompt.usage
 
         return initial_prompts, paraphrase_usage, evaluation_usage
+
+
+# TODO turn snapshot methods into instance methods of the optimizer
+def save_snapshot(
+    run_directory: Path,
+    all_prompts: list[Prompt],
+    family_tree: dict[str, tuple[str, str] | None],
+    P: list[list[str]],
+    T: int,
+    N: int,
+    task,
+    model: Llama2 | OpenAI,
+    evaluation_usage: ModelUsage,
+    evolution_usage: ModelUsage,
+    run_options: dict[str, Any],
+):
+
+    with open(run_directory / "snapshot.json", "w") as f:
+        json.dump(
+            {
+                "all_prompts": all_prompts,
+                "family_tree": family_tree,
+                "P": P,
+                "T": T,
+                "N": N,
+                "task": {
+                    "name": task.__class__.__name__,
+                    "validation_dataset": task.validation_dataset.info.dataset_name,
+                    "test_dataset": task.test_dataset.info.dataset_name,
+                    "metric": task.metric_name,
+                    "use_grammar": task.use_grammar,
+                },
+                "model": {"name": model.__class__.__name__},
+                "evaluation_usage": evaluation_usage,
+                "evolution_usage": evolution_usage,
+                "run_options": run_options,
+            },
+            f,
+            indent=4,
+            cls=OptTypeEncoder,
+        )
+
+
+def load_snapshot(path: Path):
+    with path.open("r") as f:
+        snapshot = json.load(f)
+    return (
+        snapshot["family_tree"],
+        snapshot["P"],
+        snapshot["S"],
+        snapshot["T"],
+        snapshot["N"],
+    )
diff --git a/task.py b/task.py
index 3c2dc0eecc2d065a6430c1290f1033a925c7206b..461348d35be3b41dbb79986ca5df87e47640056b 100644
--- a/task.py
+++ b/task.py
@@ -8,8 +8,9 @@ from datasets import Dataset, load_dataset
 from evaluate import load as load_metric
 from llama_cpp import LlamaGrammar
 from models import Llama2, OpenAI
+from opt_types import ModelUsage
 from tqdm import tqdm
-from utils import ModelUsage, log_calls, logger
+from utils import log_calls, logger
 
 SYSTEM_MESSAGE = """
 You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
diff --git a/utils.py b/utils.py
index 3f4781de4f347fda537b1ceae42b5f7287405a59..c6c014bc5eaf766535fc452c89fefab3db2785c4 100644
--- a/utils.py
+++ b/utils.py
@@ -1,5 +1,4 @@
 import inspect
-import json
 import logging
 import os
 import re
@@ -10,9 +9,6 @@ from textwrap import dedent, indent
 from typing import Any, Callable
 from uuid import uuid4
 
-from models import Llama2, OpenAI
-from opt_types import ModelUsage, OptTypeEncoder, Prompt
-
 current_directory = Path(__file__).resolve().parent
 logger = logging.getLogger("test-classifier")
 logger.setLevel(level=logging.DEBUG)
@@ -36,7 +32,7 @@ Only return the name without any text before or after.""".strip()
 RUNS_DIR = current_directory / "runs"
 
 
-def initialize_run_directory(model: OpenAI | Llama2):
+def initialize_run_directory(model: Callable):
     response, usage = model(None, run_name_prompt, chat=True)
     model.usage -= usage
     run_name_match = re.search(r"^\w+$", response, re.MULTILINE)
@@ -126,57 +122,3 @@ class log_calls:
 
             arguments[argument_name] = value
         return arguments
-
-
-def save_snapshot(
-    run_directory: Path,
-    all_prompts: list[Prompt],
-    family_tree: dict[str, tuple[str, str] | None],
-    P: list[list[str]],
-    T: int,
-    N: int,
-    task,
-    model: Llama2 | OpenAI,
-    evaluation_usage: ModelUsage,
-    evolution_usage: ModelUsage,
-    run_options: dict[str, Any],
-):
-
-    with open(run_directory / "snapshot.json", "w") as f:
-        json.dump(
-            {
-                "all_prompts": all_prompts,
-                "family_tree": family_tree,
-                "P": P,
-                "T": T,
-                "N": N,
-                "task": {
-                    "name": task.__class__.__name__,
-                    "validation_dataset": task.validation_dataset.info.dataset_name,
-                    "test_dataset": task.test_dataset.info.dataset_name,
-                    "metric": task.metric_name,
-                    "use_grammar": task.use_grammar,
-                },
-                "model": {"name": model.__class__.__name__},
-                "evaluation_usage": evaluation_usage,
-                "evolution_usage": evolution_usage,
-                "run_options": run_options,
-            },
-            f,
-            indent=4,
-            cls=OptTypeEncoder,
-        )
-
-
-def load_snapshot(path: Path):
-    import json
-
-    with path.open("r") as f:
-        snapshot = json.load(f)
-    return (
-        snapshot["family_tree"],
-        snapshot["P"],
-        snapshot["S"],
-        snapshot["T"],
-        snapshot["N"],
-    )