From dd98676d971c7d367f1d965d899705a3da0715a2 Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Tue, 3 Sep 2024 17:16:01 +0200
Subject: [PATCH] Only load models once

---
 evoprompt/models.py | 44 +++++++++++++++++++++++++++++---------------
 main.py             | 39 +++++++++++++++++++++++----------------
 2 files changed, 52 insertions(+), 31 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index a10005d..cc5d24e 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -24,19 +24,33 @@ ChatMessages = list[dict[str, str]]
 
 
 class LLMModel(ABC):
-    models: ClassVar[dict[str, type["LLMModel"]]] = {}
+    registered_models: ClassVar[dict[str, type["LLMModel"]]] = {}
+    # keeps track of loaded models so that we can reuse them instead of reloading them
+    loaded_models: ClassVar[dict[type["LLMModel"], tuple["LLMModel", str]]] = {}
 
     def __init_subclass__(cls) -> None:
         if inspect.isabstract(cls):
             return
-        cls.models[cls.__name__.lower()] = cls
+        LLMModel.registered_models[cls.__name__.lower()] = cls
         cls.register_arguments(argument_parser)
 
     @classmethod
     def get_model(cls, name: str, options: Namespace, **kwargs):
-        if name not in cls.models:
+        if name not in LLMModel.registered_models:
             raise ValueError("Model %s does not exist", name)
-        return cls.models[name](options=options, **kwargs)
+
+        key = cls.get_options_kwargs_hash(options, kwargs)
+        # check if model is already loaded
+        if cls in LLMModel.loaded_models:
+            model, model_key = LLMModel.loaded_models[cls]
+            if model_key != key:
+                raise ValueError(
+                    f"Model {model} is already loaded with different arguments"
+                )
+        else:
+            model = LLMModel.registered_models[name](options=options, **kwargs)
+            LLMModel.loaded_models[cls] = (model, key)
+        return model
 
     def __init__(self, options: Namespace, **kwargs):
         self.usage = ModelUsage()
@@ -100,21 +114,21 @@ class LLMModel(ABC):
         else:
             return model_completion_fn(**kwargs)
 
-    @property
-    def model_cache_key(self):
+    @staticmethod
+    def get_options_kwargs_hash(options: Namespace, kwargs):
         unique_options_key = json.dumps(
-            vars(self.options),
-            sort_keys=True,
-        ) + json.dumps(
-            self.kwargs,
+            (vars(options), kwargs),
             sort_keys=True,
         )
-        cache_key = (
+        return hashlib.sha1(unique_options_key.encode()).hexdigest()
+
+    @property
+    def model_cache_key(self):
+        return (
             str(self.model_name).replace("/", "_")
             + "/"
-            + hashlib.sha1(unique_options_key.encode()).hexdigest()
+            + self.get_options_kwargs_hash(self.options, self.kwargs)
         )
-        return cache_key
 
     @classmethod
     @abstractmethod
@@ -344,14 +358,14 @@ argument_group = argument_parser.add_argument_group("Model arguments")
 argument_group.add_argument(
     "--evolution-engine",
     type=str,
-    choices=LLMModel.models.keys(),
+    choices=LLMModel.registered_models.keys(),
     default="llama",
 )
 argument_group.add_argument(
     "--judge-engine",
     type=str,
     default=None,
-    choices=LLMModel.models.keys(),
+    choices=LLMModel.registered_models.keys(),
 )
 argument_group.add_argument(
     "--disable-cache",
diff --git a/main.py b/main.py
index 2b36dd7..3b87b48 100644
--- a/main.py
+++ b/main.py
@@ -39,6 +39,16 @@ if __name__ == "__main__":
     )
     options = argument_parser.parse_args()
 
+    # we only allow specifying one llama model, so it does not make sense to have both llama and llamachat models loaded; the evaluation engine always adapts to the loaded models
+    if options.judge_engine == "llama" and options.evolution_engine == "llamachat":
+        raise ValueError(
+            "Judge engine cannot be 'llama' when evolution engine is 'llamachat'"
+        )
+    if options.judge_engine == "llamachat" and options.evolution_engine == "llama":
+        raise ValueError(
+            "Judge engine cannot be 'llamachat' when evolution engine is 'llama'"
+        )
+
     # set up console logging and rnd
     setup_console_logger(verbosity_level=options.verbose)
     init_rng(options.seed)
@@ -61,13 +71,11 @@ if __name__ == "__main__":
     if debug:
         logger.info("DEBUG mode: Do a quick run")
 
-    # # set up evolution model
+    # set up evolution model
     evolution_model = LLMModel.get_model(options.evolution_engine, options=options)
-    match options.evolution_engine:
-        case "llama":
-            logger.info("Using Llama as the evolution engine")
-        case "openai":
-            logger.info(f"Using {options.openai_model} as the evolution engine")
+    logger.info(
+        f"Using {evolution_model.__class__.__name__.lower()} as the evolution engine"
+    )
 
     judge_model: LLMModel | None = None
     if options.judge_engine is not None:
@@ -75,17 +83,16 @@ if __name__ == "__main__":
         logger.info(f"Using {options.judge_engine} as the judge engine")
 
     # set up evaluation model
-    # NOTE currenty we always stick to Llama as evaluation engine
+    # NOTE currently we always stick to Llama (Llama or LlamaChat depending on the evolution engine) as the evaluation engine
     # TODO allow to set separate engine and model for evaluation?
-    logger.info("Using Llama as the evaluation engine")
-    evaluation_model: LLMModel
-    match options.evolution_engine:
-        case "llama" | "llamachat":
-            evaluation_model = evolution_model
-        case "openai":
-            evaluation_model = Llama(options)
-        case "openaichat":
-            evaluation_model = LlamaChat(options)
+    if isinstance(evolution_model, (Llama, LlamaChat)):
+        evaluation_model_name = evolution_model.__class__.__name__.lower()
+    elif judge_model is not None and isinstance(judge_model, (Llama, LlamaChat)):
+        evaluation_model_name = judge_model.__class__.__name__.lower()
+    else:
+        evaluation_model_name = "llamachat"
+    evaluation_model = LLMModel.get_model(name=evaluation_model_name, options=options)
+    logger.info(f"Using {evaluation_model_name} as the evaluation engine")
 
     task = get_task(options.task, evaluation_model, **options.__dict__)
     logger.info(f"Running with task {task.__class__.__name__}")
-- 
GitLab