diff --git a/evoprompt/models.py b/evoprompt/models.py
index 3131ffff3fd20593eb07502324387f23a52d133f..f9221698f99656d8236a07c8e577cacfe7a483a2 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -1,11 +1,12 @@
-import functools
+import hashlib
 import inspect
+import json
 import logging
+import warnings
 from abc import ABC, abstractmethod
 from argparse import ArgumentParser, Namespace
 from pathlib import Path
 from typing import Any, Callable, ClassVar
-import warnings
 
 import llama_cpp
 import openai
@@ -19,10 +20,11 @@ logger = logging.getLogger(__name__)
 logging.captureWarnings(True)
 warnings.simplefilter("once")
 
+ChatMessages = list[dict[str, str]]
+
 
 class LLMModel(ABC):
     models: ClassVar[dict[str, type["LLMModel"]]] = {}
-    chat: bool
 
     def __init_subclass__(cls) -> None:
         if inspect.isabstract(cls):
@@ -36,26 +38,25 @@ class LLMModel(ABC):
             raise ValueError("Model %s does not exist", name)
         return cls.models[name](options=options, **kwargs)
 
-    @functools.lru_cache
-    def _compute_cache_key(self, name, **kwargs):
-        # we use a tuple of the model name, the options, and the kwargs as the cache key
-        return (name,) + tuple((key, value) for key, value in kwargs.items())
-
     def __init__(self, options: Namespace, **kwargs):
         self.usage = ModelUsage()
-        self.chat = options.chat
         # store kwargs for caching
         self.options = options
         self.kwargs = kwargs
 
         # set up caching for model calls
+        self._call_model_cached = None
         if not options.disable_cache:
             cache = Cache(Path(".cache_dir", self.model_cache_key))
-            self._call_model_cached = cache.memoize(typed=True, ignore=[0, "func"])(
-                self._call_model_cached
-            )
+            @cache.memoize(typed=True, ignore=[0, "func"])
+            def _call_function(func, *args, **kwargs):
+                return func(*args, **kwargs)
+
+            self._call_model_cached = _call_function
+
+    @abstractmethod
     def create_completion(
         self,
         system_message: str | None,
@@ -65,48 +66,10 @@ class LLMModel(ABC):
         prompt_appendix: str = "",
         prompt_prefix: str = "",
         prompt_suffix: str = "",
-        chat: bool | None = None,
         stop: str = None,
-        history: dict = None,
+        history: ChatMessages | None = None,
         **kwargs: Any,
-    ) -> tuple[str, ModelUsage]:
-        if chat is None:
-            chat = self.chat
-        max_tokens = kwargs.pop("max_tokens", self.options.max_tokens)
-
-        # create prompt
-        prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
-
-        if not chat and system_message:
-            prompt = system_message + prompt
-
-        messages = [self._get_user_message(prompt)]
-
-        if chat:
-            # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
-            # TODO is it better to check for a system message in the history?
-            if history is not None:
-                messages = history + messages
-            elif system_message:
-                messages.insert(
-                    0,
-                    self._get_system_message(system_message),
-                )
-            model_input = {"messages": messages}
-        else:
-            model_input = {"prompt": prompt}
-
-        reponse, usage = self._create_completion(
-            chat=chat,
-            **model_input,
-            stop=stop,
-            max_tokens=max_tokens,
-            use_cache=use_cache,
-            **kwargs,
-        )
-
-        messages.append(self._get_assistant_message(reponse))
-        return reponse, messages, usage
+    ) -> tuple[str, ModelUsage]: ...
 
     def _get_user_message(self, content: str):
         return {
@@ -134,18 +97,26 @@ class LLMModel(ABC):
             warnings.warn("Caching is disabled when a grammar is provided.")
             use_cache = False
 
-        if use_cache:
-            # use cached function call
-            cache_key = self._compute_cache_key(
-                model_completion_fn.__name__, **self.options.__dict__, **self.kwargs
-            )
-            return self._call_model_cached(model_completion_fn, cache_key, **kwargs)
+        if use_cache and self._call_model_cached is not None:
+            return self._call_model_cached(model_completion_fn, **kwargs)
         else:
             return model_completion_fn(**kwargs)
 
-    def _call_model_cached(self, func, cache_key, *args, **kwargs):
-        # `cache_key` is added to the cache key (e.g., to distinguish between different models), but it is not used in the function
-        return func(*args, **kwargs)
+    @property
+    def model_cache_key(self):
+        unique_options_key = json.dumps(
+            vars(self.options),
+            sort_keys=True,
+        ) + json.dumps(
+            self.kwargs,
+            sort_keys=True,
+        )
+        cache_key = (
+            str(self.model_name).replace("/", "_")
+            + "/"
+            + hashlib.sha1(unique_options_key.encode()).hexdigest()
+        )
+        return cache_key
 
     @classmethod
     @abstractmethod
@@ -205,34 +176,51 @@ class Llama(LLMModel):
         # needs to be called after model is initialized
         super().__init__(options=options, n_ctx=n_ctx, **kwargs)
 
+    def create_completion(
+        self,
+        system_message: str | None,
+        prompt: str,
+        *,
+        use_cache: bool = False,
+        prompt_appendix: str = "",
+        prompt_prefix: str = "",
+        prompt_suffix: str = "",
+        stop: str = None,
+        history: ChatMessages | None = None,
+        **kwargs: Any,
+    ) -> tuple[str, ModelUsage]:
+        # create prompt
+        prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
+        messages = [self._get_user_message(prompt)]
+        if system_message is not None:
+            prompt = system_message + prompt
+
+        reponse, usage = self._create_completion(
+            prompt=prompt,
+            stop=stop,
+            use_cache=use_cache,
+            max_tokens=self.options.max_tokens,
+            **kwargs,
+        )
+
+        messages.append(self._get_assistant_message(reponse))
+        return reponse, messages, usage
+
     def _create_completion(
         self,
-        chat: bool,
         use_cache: bool = False,
         **kwargs,
     ):
-        if chat:
-            response = self._call_model(
-                self.model.create_chat_completion,
-                use_cache=use_cache,
-                **kwargs,
-            )
-            response_text = response["choices"][0]["message"]["content"]
-        else:
-            response = self._call_model(
-                self.model.create_completion,
-                use_cache=use_cache,
-                **kwargs,
-            )
-            response_text = response["choices"][0]["text"]
+        response = self._call_model(
+            self.model.create_completion,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        response_text = response["choices"][0]["text"]
 
         usage = ModelUsage(**response["usage"])
         return response_text, usage
 
-    @property
-    def model_cache_key(self):
-        return self.model_name
-
     @classmethod
     def register_arguments(cls, parser: ArgumentParser):
         group = parser.add_argument_group(f"{cls.__name__} model arguments")
@@ -272,7 +260,61 @@ class Llama(LLMModel):
         )
 
 
-class OpenAI(LLMModel):
+class ChatModel:
+
+    def create_completion(
+        self,
+        system_message: str | None,
+        prompt: str,
+        *,
+        use_cache: bool = False,
+        prompt_appendix: str = "",
+        prompt_prefix: str = "",
+        prompt_suffix: str = "",
+        stop: str = None,
+        history: ChatMessages | None = None,
+        **kwargs: Any,
+    ) -> tuple[str, ModelUsage]:
+        # create prompt
+        prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
+        messages = [self._get_user_message(prompt)]
+
+        # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
+        # TODO is it better to check for a system message in the history?
+        if history is None and system_message:
+            history = [self._get_system_message(system_message)]
+
+        reponse, usage = self._create_completion(
+            messages=messages,
+            stop=stop,
+            use_cache=use_cache,
+            max_tokens=self.options.max_tokens,
+            **kwargs,
+        )
+
+        messages.append(self._get_assistant_message(reponse))
+        return reponse, history + messages, usage
+
+
+class LlamaChat(ChatModel, Llama):
+
+    def _create_completion(
+        self,
+        use_cache: bool = False,
+        **kwargs,
+    ):
+        response = self._call_model(
+            self.model.create_chat_completion,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        response_text = response["choices"][0]["message"]["content"]
+
+        usage = ModelUsage(**response["usage"])
+        return response_text, usage
+
+
+class OpenAiChat(ChatModel, LLMModel):
     """Queries an OpenAI model using its API."""
 
     def __init__(
@@ -288,34 +330,19 @@ class OpenAI(LLMModel):
 
     def _create_completion(
         self,
-        chat: bool,
         use_cache: bool = False,
         **kwargs,
     ):
-        if chat:
-            response = self._call_model(
-                self.openai_client.chat.completions.create,
-                model=self.model_name,
-                use_cache=use_cache,
-                **kwargs,
-            )
-            response_text = response.choices[0].message.content
-        else:
-            response = self._call_model(
-                self.openai_client.completions.create,
-                model=self.model,
-                use_cache=use_cache,
-                **kwargs,
-            )
-            response_text = response.choices[0].text
-
+        response = self._call_model(
+            self.openai_client.chat.completions.create,
+            model=self.model_name,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        response_text = response.choices[0].message.content
         usage = ModelUsage(**response.usage.__dict__)
         return response_text, usage
 
-    @property
-    def model_cache_key(self):
-        return self.model_name
-
     @classmethod
     def register_arguments(cls, parser: ArgumentParser):
         group = parser.add_argument_group("OpenAI model arguments")
@@ -339,4 +366,3 @@ argument_group.add_argument(
     type=int,
     help="Maximum number of tokens being generated from LLM. ",
 )
-argument_group.add_argument("--chat", "-c", action="store_true")
diff --git a/main.py b/main.py
index 3ed613293ecbe5b0deb4c7e18205c749ed16a79a..e0b8faedc0485ebce7a65ac35969faa971580a74 100644
--- a/main.py
+++ b/main.py
@@ -6,7 +6,7 @@ from dotenv import load_dotenv
 
 from evoprompt.cli import argument_parser
 from evoprompt.evolution import get_optimizer_class
-from evoprompt.models import Llama, LLMModel
+from evoprompt.models import Llama, LlamaChat, LLMModel
 from evoprompt.task import get_task
 from evoprompt.utils import init_rng, setup_console_logger
 
@@ -61,7 +61,7 @@ if __name__ == "__main__":
     if debug:
         logger.info("DEBUG mode: Do a quick run")
 
-    # set up evolution model
+    # # set up evolution model
     evolution_model = LLMModel.get_model(options.evolution_engine, options=options)
 
     match options.evolution_engine:
@@ -76,10 +76,12 @@ if __name__ == "__main__":
     logger.info("Using Llama as the evaluation engine")
     evaluation_model: LLMModel
     match options.evolution_engine:
-        case "llama":
+        case "llama" | "llamachat":
            evaluation_model = evolution_model
         case "openai":
             evaluation_model = Llama(options)
+        case "openaichat":
+            evaluation_model = LlamaChat(options)
 
     task = get_task(options.task, evaluation_model, **options.__dict__)
     logger.info(f"Running with task {task.__class__.__name__}")
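As a companion to the diff (not part of the change itself), here is a small, self-contained sketch of the dispatch pattern the refactor moves to: subclasses register themselves under their lowercased class names via __init_subclass__, which appears to be what lets main.py select "llamachat" or "openaichat" by name, and chat behaviour lives in a mixin that builds a message list instead of threading a per-call chat flag. All names in the sketch (Model, EchoCompletion, ChatMixin, EchoChat, Model.get) are illustrative stand-ins, not the evoprompt API, and the lowercased-class-name registry is an assumption inferred from the match arms in main.py.

import inspect
from abc import ABC, abstractmethod

ChatMessages = list[dict[str, str]]


class Model(ABC):
    # registry of concrete models, keyed by lowercased class name
    registry: dict[str, type["Model"]] = {}

    def __init_subclass__(cls) -> None:
        # skip abstract intermediate classes, mirroring the isabstract check in the diff
        if inspect.isabstract(cls):
            return
        cls.registry[cls.__name__.lower()] = cls

    @classmethod
    def get(cls, name: str) -> "Model":
        # look up a concrete model by its registry name, e.g. "echochat"
        return cls.registry[name]()

    @abstractmethod
    def create_completion(self, system_message: str | None, prompt: str) -> str: ...


class EchoCompletion(Model):
    def create_completion(self, system_message: str | None, prompt: str) -> str:
        # plain-completion path: the system message is simply prepended to the prompt
        return f"completion({(system_message or '') + prompt!r})"


class ChatMixin:
    def create_completion(self, system_message: str | None, prompt: str) -> str:
        # chat path: build a message list; the system message becomes its own message
        messages: ChatMessages = []
        if system_message:
            messages.append({"role": "system", "content": system_message})
        messages.append({"role": "user", "content": prompt})
        return f"chat_completion({messages!r})"


class EchoChat(ChatMixin, Model):
    pass


if __name__ == "__main__":
    print(Model.get("echocompletion").create_completion("Be terse. ", "Summarize X."))
    print(Model.get("echochat").create_completion("Be terse.", "Summarize X."))

The diff's own classes follow the same shape: Llama keeps the plain-completion path, ChatModel owns the message handling, and LlamaChat and OpenAiChat combine the two with their llama.cpp and OpenAI backends respectively.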