Commit dd98676d authored by Max Kimmich

Only load models once

parent 31c10c46
@@ -24,19 +24,33 @@ ChatMessages = list[dict[str, str]]
 class LLMModel(ABC):
-    models: ClassVar[dict[str, type["LLMModel"]]] = {}
+    registered_models: ClassVar[dict[str, type["LLMModel"]]] = {}
+    # keeps track of loaded models so that we can reuse them instead of reloading them
+    loaded_models: ClassVar[dict[type["LLMModel"], tuple["LLMModel", str]]] = {}

     def __init_subclass__(cls) -> None:
         if inspect.isabstract(cls):
             return
-        cls.models[cls.__name__.lower()] = cls
+        LLMModel.registered_models[cls.__name__.lower()] = cls
         cls.register_arguments(argument_parser)

     @classmethod
     def get_model(cls, name: str, options: Namespace, **kwargs):
-        if name not in cls.models:
+        if name not in LLMModel.registered_models:
             raise ValueError(f"Model {name} does not exist")
-        return cls.models[name](options=options, **kwargs)
+        key = cls.get_options_kwargs_hash(options, kwargs)
+        # check if the model is already loaded
+        if cls in LLMModel.loaded_models:
+            model, model_key = LLMModel.loaded_models[cls]
+            if model_key != key:
+                raise ValueError(
+                    f"Model {model} is already loaded with different arguments"
+                )
+        else:
+            model = LLMModel.registered_models[name](options=options, **kwargs)
+            LLMModel.loaded_models[cls] = (model, key)
+        return model

     def __init__(self, options: Namespace, **kwargs):
         self.usage = ModelUsage()
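The effect of the new `loaded_models` cache is that repeated `get_model` calls hand back the same instance instead of reloading weights, and a second request with different arguments fails fast. A minimal usage sketch of that intended behavior (the engine name, the option values, and the `temperature` kwarg here are illustrative, not from this commit):

    # Illustrative sketch: the first call loads the model, the second reuses it.
    options = argument_parser.parse_args(["--evolution-engine", "llama"])

    model_a = LLMModel.get_model("llama", options=options)
    model_b = LLMModel.get_model("llama", options=options)
    assert model_a is model_b  # same options/kwargs hash -> cached instance

    # Different kwargs change the hash, so the cache refuses to silently
    # return a model that was loaded with other arguments.
    try:
        LLMModel.get_model("llama", options=options, temperature=1.0)
    except ValueError:
        pass  # "already loaded with different arguments"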
@@ -100,21 +114,21 @@ class LLMModel(ABC):
         else:
             return model_completion_fn(**kwargs)

-    @property
-    def model_cache_key(self):
+    @staticmethod
+    def get_options_kwargs_hash(options: Namespace, kwargs):
         unique_options_key = json.dumps(
-            vars(self.options),
-            sort_keys=True,
-        ) + json.dumps(
-            self.kwargs,
+            (vars(options), kwargs),
             sort_keys=True,
         )
-        cache_key = (
+        return hashlib.sha1(unique_options_key.encode()).hexdigest()
+
+    @property
+    def model_cache_key(self):
+        return (
             str(self.model_name).replace("/", "_")
             + "/"
-            + hashlib.sha1(unique_options_key.encode()).hexdigest()
+            + self.get_options_kwargs_hash(self.options, self.kwargs)
         )
-        return cache_key

     @classmethod
     @abstractmethod
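Since the helper serializes `options` and `kwargs` together with sorted keys, equal inputs always map to the same SHA-1 digest, which is what makes the cache lookup in `get_model` reliable. A rough standalone illustration of the scheme (dummy values; note that `kwargs` must be JSON-serializable for this to work):

    import hashlib
    import json
    from argparse import Namespace

    # Dummy values, for illustration only.
    options = Namespace(evolution_engine="llama", seed=42)
    kwargs = {"max_tokens": 100}

    unique_options_key = json.dumps((vars(options), kwargs), sort_keys=True)
    print(hashlib.sha1(unique_options_key.encode()).hexdigest())
    # identical options/kwargs -> identical digest across calls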
@@ -344,14 +358,14 @@ argument_group = argument_parser.add_argument_group("Model arguments")
 argument_group.add_argument(
     "--evolution-engine",
     type=str,
-    choices=LLMModel.models.keys(),
+    choices=LLMModel.registered_models.keys(),
     default="llama",
 )
 argument_group.add_argument(
     "--judge-engine",
     type=str,
     default=None,
-    choices=LLMModel.models.keys(),
+    choices=LLMModel.registered_models.keys(),
 )
 argument_group.add_argument(
     "--disable-cache",
@@ -39,6 +39,16 @@ if __name__ == "__main__":
     )
     options = argument_parser.parse_args()

+    # we only allow specifying one llama model, so it does not make sense to have both llama and llamachat models loaded; the evaluation engine always adapts to the loaded models
+    if options.judge_engine == "llama" and options.evolution_engine == "llamachat":
+        raise ValueError(
+            "Judge engine cannot be 'llama' when evolution engine is 'llamachat'"
+        )
+    if options.judge_engine == "llamachat" and options.evolution_engine == "llama":
+        raise ValueError(
+            "Judge engine cannot be 'llamachat' when evolution engine is 'llama'"
+        )

     # set up console logging and rng
     setup_console_logger(verbosity_level=options.verbose)
     init_rng(options.seed)
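The two guards are symmetric, so an equivalent and slightly more compact formulation (just a sketch, not what the commit does) would be:

    # Sketch: llama and llamachat are mutually exclusive across the two
    # engine roles, whichever role each one occupies.
    if {options.evolution_engine, options.judge_engine} >= {"llama", "llamachat"}:
        raise ValueError("Cannot mix 'llama' and 'llamachat' engines")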
@@ -61,13 +71,11 @@ if __name__ == "__main__":
     if debug:
         logger.info("DEBUG mode: Do a quick run")

-    # # set up evolution model
+    # set up evolution model
     evolution_model = LLMModel.get_model(options.evolution_engine, options=options)
-    match options.evolution_engine:
-        case "llama":
-            logger.info("Using Llama as the evolution engine")
-        case "openai":
-            logger.info(f"Using {options.openai_model} as the evolution engine")
+    logger.info(
+        f"Using {evolution_model.__class__.__name__.lower()} as the evolution engine"
+    )

     judge_model: LLMModel | None = None
     if options.judge_engine is not None:
@@ -75,17 +83,16 @@ if __name__ == "__main__":
         logger.info(f"Using {options.judge_engine} as the judge engine")

     # set up evaluation model
-    # NOTE currenty we always stick to Llama as evaluation engine
+    # NOTE currently we always stick to Llama (Llama or LlamaChat depending on the evolution engine) as the evaluation engine
     # TODO allow setting a separate engine and model for evaluation?
-    logger.info("Using Llama as the evaluation engine")
-    evaluation_model: LLMModel
-    match options.evolution_engine:
-        case "llama" | "llamachat":
-            evaluation_model = evolution_model
-        case "openai":
-            evaluation_model = Llama(options)
-        case "openaichat":
-            evaluation_model = LlamaChat(options)
+    if isinstance(evolution_model, (Llama, LlamaChat)):
+        evaluation_model_name = evolution_model.__class__.__name__.lower()
+    elif judge_model is not None and isinstance(judge_model, (Llama, LlamaChat)):
+        evaluation_model_name = judge_model.__class__.__name__.lower()
+    else:
+        evaluation_model_name = "llamachat"
+    evaluation_model = LLMModel.get_model(name=evaluation_model_name, options=options)
+    logger.info(f"Using {evaluation_model_name} as the evaluation engine")

     task = get_task(options.task, evaluation_model, **options.__dict__)
     logger.info(f"Running with task {task.__class__.__name__}")