From dd98676d971c7d367f1d965d899705a3da0715a2 Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Tue, 3 Sep 2024 17:16:01 +0200
Subject: [PATCH] Only load models once

---
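Background on the caching contract: LLMModel.get_model now keeps one
instance per registered model name, so repeated calls with the same name
and options return the identical object, while requesting an already-loaded
name with different options fails loudly. A minimal sketch of the intended
behavior (the Namespace fields below are placeholders for illustration,
not the project's real options):

    from argparse import Namespace

    from evoprompt.models import LLMModel

    options = Namespace(seed=42)  # placeholder options, illustration only

    m1 = LLMModel.get_model("llama", options=options)
    m2 = LLMModel.get_model("llama", options=options)
    assert m1 is m2  # second call reuses the already-loaded instance

    # the same name with a different options hash is rejected
    try:
        LLMModel.get_model("llama", options=Namespace(seed=1))
    except ValueError:
        pass  # "... is already loaded with different arguments"
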
 evoprompt/models.py | 44 +++++++++++++++++++++++++++++---------------
 main.py             | 39 +++++++++++++++++++++++----------------
 2 files changed, 52 insertions(+), 31 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index a10005d..cc5d24e 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -24,19 +24,33 @@ ChatMessages = list[dict[str, str]]
 
 
 class LLMModel(ABC):
-    models: ClassVar[dict[str, type["LLMModel"]]] = {}
+    registered_models: ClassVar[dict[str, type["LLMModel"]]] = {}
+    # keeps track of loaded models so that we can reuse them instead of reloading them
+    loaded_models: ClassVar[dict[str, tuple["LLMModel", str]]] = {}
 
     def __init_subclass__(cls) -> None:
         if inspect.isabstract(cls):
             return
-        cls.models[cls.__name__.lower()] = cls
+        LLMModel.registered_models[cls.__name__.lower()] = cls
         cls.register_arguments(argument_parser)
 
     @classmethod
     def get_model(cls, name: str, options: Namespace, **kwargs):
-        if name not in cls.models:
+        if name not in LLMModel.registered_models:
             raise ValueError("Model %s does not exist", name)
-        return cls.models[name](options=options, **kwargs)
+
+        key = cls.get_options_kwargs_hash(options, kwargs)
+        # check if model is already loaded
+        if name in LLMModel.loaded_models:
+            model, model_key = LLMModel.loaded_models[name]
+            if model_key != key:
+                raise ValueError(
+                    f"Model {model} is already loaded with different arguments"
+                )
+        else:
+            model = LLMModel.registered_models[name](options=options, **kwargs)
+            LLMModel.loaded_models[name] = (model, key)
+        return model
 
     def __init__(self, options: Namespace, **kwargs):
         self.usage = ModelUsage()
@@ -100,21 +114,21 @@ class LLMModel(ABC):
         else:
             return model_completion_fn(**kwargs)
 
-    @property
-    def model_cache_key(self):
+    @staticmethod
+    def get_options_kwargs_hash(options: Namespace, kwargs: dict):
         unique_options_key = json.dumps(
-            vars(self.options),
-            sort_keys=True,
-        ) + json.dumps(
-            self.kwargs,
+            (vars(options), kwargs),
             sort_keys=True,
         )
-        cache_key = (
+        return hashlib.sha1(unique_options_key.encode()).hexdigest()
+
+    @property
+    def model_cache_key(self):
+        return (
             str(self.model_name).replace("/", "_")
             + "/"
-            + hashlib.sha1(unique_options_key.encode()).hexdigest()
+            + self.get_options_kwargs_hash(self.options, self.kwargs)
         )
-        return cache_key
 
     @classmethod
     @abstractmethod
@@ -344,14 +358,14 @@ argument_group = argument_parser.add_argument_group("Model arguments")
 argument_group.add_argument(
     "--evolution-engine",
     type=str,
-    choices=LLMModel.models.keys(),
+    choices=LLMModel.registered_models.keys(),
     default="llama",
 )
 argument_group.add_argument(
     "--judge-engine",
     type=str,
     default=None,
-    choices=LLMModel.models.keys(),
+    choices=LLMModel.registered_models.keys(),
 )
 argument_group.add_argument(
     "--disable-cache",
diff --git a/main.py b/main.py
index 2b36dd7..3b87b48 100644
--- a/main.py
+++ b/main.py
@@ -39,6 +39,16 @@ if __name__ == "__main__":
     )
     options = argument_parser.parse_args()
 
+    # only one Llama model may be specified, so it makes no sense to load both llama and llamachat; the evaluation engine always adapts to the loaded models
+    if options.judge_engine == "llama" and options.evolution_engine == "llamachat":
+        raise ValueError(
+            "Judge engine cannot be 'llama' when evolution engine is 'llamachat'"
+        )
+    if options.judge_engine == "llamachat" and options.evolution_engine == "llama":
+        raise ValueError(
+            "Judge engine cannot be 'llamachat' when evolution engine is 'llama'"
+        )
+
     # set up console logging and rnd
     setup_console_logger(verbosity_level=options.verbose)
     init_rng(options.seed)
@@ -61,13 +71,11 @@ if __name__ == "__main__":
     if debug:
         logger.info("DEBUG mode: Do a quick run")
 
-    # # set up evolution model
+    # set up evolution model
     evolution_model = LLMModel.get_model(options.evolution_engine, options=options)
-    match options.evolution_engine:
-        case "llama":
-            logger.info("Using Llama as the evolution engine")
-        case "openai":
-            logger.info(f"Using {options.openai_model} as the evolution engine")
+    logger.info(
+        f"Using {evolution_model.__class__.__name__.lower()} as the evolution engine"
+    )
 
     judge_model: LLMModel | None = None
     if options.judge_engine is not None:
@@ -75,17 +83,16 @@ if __name__ == "__main__":
         logger.info(f"Using {options.judge_engine} as the judge engine")
 
     # set up evaluation model
-    # NOTE currenty we always stick to Llama as evaluation engine
+    # NOTE currently we always stick to Llama (Llama or LlamaChat, depending on the evolution engine) as the evaluation engine
     # TODO allow to set separate engine and model for evaluation?
-    logger.info("Using Llama as the evaluation engine")
-    evaluation_model: LLMModel
-    match options.evolution_engine:
-        case "llama" | "llamachat":
-            evaluation_model = evolution_model
-        case "openai":
-            evaluation_model = Llama(options)
-        case "openaichat":
-            evaluation_model = LlamaChat(options)
+    if isinstance(evolution_model, (Llama, LlamaChat)):
+        evaluation_model_name = evolution_model.__class__.__name__.lower()
+    elif isinstance(judge_model, (Llama, LlamaChat)):
+        evaluation_model_name = judge_model.__class__.__name__.lower()
+    else:
+        evaluation_model_name = "llamachat"
+    evaluation_model = LLMModel.get_model(name=evaluation_model_name, options=options)
+    logger.info(f"Using {evaluation_model_name} as the evaluation engine")
 
     task = get_task(options.task, evaluation_model, **options.__dict__)
     logger.info(f"Running with task {task.__class__.__name__}")
-- 
GitLab