diff --git a/evoprompt/evolution.py b/evoprompt/evolution.py
index 3891015aafbb4f86e26f5bac9487a1cd0e02df98..ae909f2aa6c8f7cc56c20b330cb80d0f808288ea 100644
--- a/evoprompt/evolution.py
+++ b/evoprompt/evolution.py
@@ -180,7 +180,7 @@ class GeneticAlgorithm(EvolutionAlgorithm):
         # Based on this two-step process, we design instructions, guiding LLMs to
         # generate a new prompt based on these steps to perform Evo(·) in Algorithm 1.
 
-        evolved_prompt, _, usage = self.evolution_model(
+        evolved_prompt, _, usage = self.evolution_model.create_completion(
             system_message=SYSTEM_MESSAGE,
             prompt=GA_PROMPT.format(prompt1=prompt_1, prompt2=prompt_2),
         )
@@ -240,7 +240,7 @@ class DifferentialEvolution(EvolutionAlgorithm):
             prompts_current_evolution, key=lambda prompt: prompt.score
         )
 
-        evolved_prompt, _, usage = self.evolution_model(
+        evolved_prompt, _, usage = self.evolution_model.create_completion(
             system_message=SYSTEM_MESSAGE,
             prompt=get_de_prompt_template(self.use_evolution_demo, self.task).format(
                 prompt1=prompt_1,
@@ -315,7 +315,7 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
 
         messages = None
         for idx, prompt in enumerate(DE_COT_PROMPTS):
-            response, messages, usage = self.evolution_model(
+            response, messages, usage = self.evolution_model.create_completion(
                 system_message=SYSTEM_MESSAGE,
                 prompt=prompt.format(
                     prompt1=prompt_1,
diff --git a/evoprompt/models.py b/evoprompt/models.py
index 6ed1b3576be101c9b07e8ede923f7bc8460d89a4..f9221698f99656d8236a07c8e577cacfe7a483a2 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -1,37 +1,30 @@
+import hashlib
 import inspect
+import json
 import logging
+import warnings
 from abc import ABC, abstractmethod
 from argparse import ArgumentParser, Namespace
 from pathlib import Path
-from typing import Any, ClassVar
+from typing import Any, Callable, ClassVar
 
-import joblib
 import llama_cpp
 import openai
+from diskcache import Cache
 
 from evoprompt.cli import argument_parser
 from evoprompt.opt_types import ModelUsage
 from evoprompt.utils import get_seed
 
 logger = logging.getLogger(__name__)
+logging.captureWarnings(True)
+warnings.simplefilter("once")
 
-
-current_directory = Path(__file__).resolve().parent
-
-mem = joblib.Memory(location=".cache_dir", verbose=0)
-
-
-@mem.cache
-def get_model_completion(model, chat, *args, **kwargs):
-    if chat:
-        return model.create_chat_completion(*args, **kwargs)
-    else:
-        return model.create_completion(*args, **kwargs)
+ChatMessages = list[dict[str, str]]
 
 
 class LLMModel(ABC):
     models: ClassVar[dict[str, type["LLMModel"]]] = {}
-    chat: bool
 
     def __init_subclass__(cls) -> None:
         if inspect.isabstract(cls):
@@ -40,65 +33,90 @@ class LLMModel(ABC):
         cls.register_arguments(argument_parser)
 
     @classmethod
-    def get_model(cls, name: str, options: Namespace):
+    def get_model(cls, name: str, options: Namespace, **kwargs):
         if name not in cls.models:
             raise ValueError("Model %s does not exist", name)
-        return cls.models[name](options)
+        return cls.models[name](options=options, **kwargs)
 
-    def __init__(self, options: Namespace):
+    def __init__(self, options: Namespace, **kwargs):
         self.usage = ModelUsage()
-        self.chat = options.chat
-
-    def create_completion(
-        self,
-        system_message: str | None,
-        prompt: str,
-        *,
-        prompt_appendix: str = "",
-        prompt_prefix: str = "",
-        prompt_suffix: str = "",
-        chat: bool | None = None,
-        stop: str = None,
-        max_tokens: int = None,
-        history: dict = None,
-        **kwargs: Any,
-    ) -> tuple[str, ModelUsage]:
-        return self._create_completion(
-            system_message=system_message,
-            prompt=prompt,
-            prompt_appendix=prompt_appendix,
-            prompt_prefix=prompt_prefix,
-            prompt_suffix=prompt_suffix,
-            chat=chat,
-            stop=stop,
-            max_tokens=max_tokens,
-            history=history,
-            **kwargs,
-        )
+        # store kwargs for caching
+        self.options = options
+        self.kwargs = kwargs
+
+        # set up caching for model calls
+        self._call_model_cached = None
+        if not options.disable_cache:
+            cache = Cache(Path(".cache_dir", self.model_cache_key))
+
+            @cache.memoize(typed=True, ignore=[0, "func"])
+            def _call_function(func, *args, **kwargs):
+                return func(*args, **kwargs)
+
+            self._call_model_cached = _call_function
 
     @abstractmethod
-    def _create_completion(
+    def create_completion(
         self,
         system_message: str | None,
         prompt: str,
         *,
+        use_cache: bool = False,
         prompt_appendix: str = "",
         prompt_prefix: str = "",
         prompt_suffix: str = "",
-        chat: bool | None = None,
         stop: str = None,
-        max_tokens: int = None,
-        history: dict = None,
+        history: ChatMessages | None = None,
         **kwargs: Any,
-    ) -> tuple[str, ModelUsage]:
-        pass
-
-    def __call__(
-        self,
-        *args,
-        **kwargs,
-    ) -> tuple[str, ModelUsage]:
-        return self.create_completion(*args, **kwargs)
+    ) -> tuple[str, ModelUsage]: ...
+
+    def _get_user_message(self, content: str):
+        return {
+            "role": "user",
+            "content": content,
+        }
+
+    def _get_system_message(self, content: str):
+        return {
+            "role": "system",
+            "content": content,
+        }
+
+    def _get_assistant_message(self, content: str):
+        return {
+            "role": "assistant",
+            "content": content,
+        }
+
+    def _call_model(
+        self, model_completion_fn: Callable, use_cache: bool = False, **kwargs
+    ):
+        if "grammar" in kwargs and kwargs["grammar"] is not None:
+            # grammar cannot be pickled therefore we cannot use caching when a grammar is provided
+            warnings.warn("Caching is disabled when a grammar is provided.")
+            use_cache = False
+
+        if use_cache and self._call_model_cached is not None:
+            return self._call_model_cached(model_completion_fn, **kwargs)
+        else:
+            return model_completion_fn(**kwargs)
+
+    @property
+    def model_cache_key(self):
+        unique_options_key = json.dumps(
+            vars(self.options),
+            sort_keys=True,
+        ) + json.dumps(
+            self.kwargs,
+            sort_keys=True,
+        )
+        cache_key = (
+            str(self.model_name).replace("/", "_")
+            + "/"
+            + hashlib.sha1(unique_options_key.encode()).hexdigest()
+        )
+        return cache_key
 
     @classmethod
     @abstractmethod
@@ -116,25 +134,26 @@ class Llama(LLMModel):
         n_ctx: int = 4096,
         **kwargs,
     ) -> None:
-
-        super().__init__(options)
-
         # initialize model
+        add_kwargs = {}
         seed = get_seed()
         if seed is not None:
-            kwargs["seed"] = seed
+            add_kwargs["seed"] = seed
+
         if options.llama_path is not None:
             # use local file
             self.model = llama_cpp.Llama(
-                str(options.llama_path),
+                model_path=options.llama_path,
                 chat_format=options.chat_format,
                 chat_handler=options.chat_handler,
                 verbose=options.verbose > 1 or options.llama_verbose,
                 n_gpu_layers=n_gpu_layers,
                 n_threads=n_threads,
                 n_ctx=n_ctx,
+                **add_kwargs,
                 **kwargs,
             )
+            self.model_name = Path(options.llama_path).stem
         else:
             # use pre-trained model from HF hub
             self.model = llama_cpp.Llama.from_pretrained(
@@ -146,93 +165,61 @@ class Llama(LLMModel):
                 n_gpu_layers=n_gpu_layers,
                 n_threads=n_threads,
                 n_ctx=n_ctx,
-                # max_tokens=2,
+                **add_kwargs,
                 **kwargs,
             )
+            self.model_name = Path(
+                options.llama_model, options.llama_model_file
+            ).with_suffix("")
 
-    def _create_completion(
+        # pass all arguments to super constructor which should be taken into account for caching
+        # needs to be called after model is initialized
+        super().__init__(options=options, n_ctx=n_ctx, **kwargs)
+
+    def create_completion(
         self,
         system_message: str | None,
         prompt: str,
         *,
+        use_cache: bool = False,
         prompt_appendix: str = "",
         prompt_prefix: str = "",
         prompt_suffix: str = "",
-        chat: bool | None = None,
         stop: str = None,
-        max_tokens: int = None,
-        history: dict = None,
+        history: ChatMessages | None = None,
         **kwargs: Any,
     ) -> tuple[str, ModelUsage]:
-        if chat is None:
-            chat = self.chat
-
-        if chat:
-            messages = [
-                {
-                    "role": "user",
-                    "content": prompt_prefix + prompt + prompt_suffix + prompt_appendix,
-                }
-            ]
-            # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
-            # TODO is it better to check for a system message?
-            if history is not None:
-                messages = history + messages
-            elif system_message:
-                messages.insert(
-                    0,
-                    {
-                        "role": "system",
-                        "content": system_message,
-                    },
-                )
-            if "grammar" in kwargs and kwargs["grammar"] is not None:
-                # grammer cannot be pickled therefore we cannot use caching when a grammar is provided
-                # logger.warnning
-                model_fn = get_model_completion.func
-            else:
-                # use cached function
-                model_fn = get_model_completion
-
-            response = model_fn(
-                self.model,
-                chat,
-                messages=messages,
-                stop=stop,
-                max_tokens=max_tokens,
-                **kwargs,
-            )
-            response_text = response["choices"][0]["message"]["content"]
-        else:
-            prompt = (
-                (system_message if system_message else "")
-                + prompt_prefix
-                + prompt
-                + prompt_suffix
-                + prompt_appendix
-            )
-            response = self.model.create_completion(
-                prompt=prompt,
-                stop=stop,
-                max_tokens=max_tokens,
-                **kwargs,
-            )
-            response_text = response["choices"][0]["text"]
-            messages = [
-                {
-                    "role": "user",
-                    "content": prompt,
-                }
-            ]
-        messages.append(
-            {
-                "role": "assistant",
-                "content": response_text,
-            }
+        # create prompt
+        prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
+        messages = [self._get_user_message(prompt)]
+        if system_message is not None:
+            prompt = system_message + prompt
+
+        response, usage = self._create_completion(
+            prompt=prompt,
+            stop=stop,
+            use_cache=use_cache,
+            max_tokens=self.options.max_tokens,
+            **kwargs,
+        )
+
+        messages.append(self._get_assistant_message(response))
+        return response, messages, usage
+
+    def _create_completion(
+        self,
+        use_cache: bool = False,
+        **kwargs,
+    ):
+        response = self._call_model(
+            self.model.create_completion,
+            use_cache=use_cache,
+            **kwargs,
         )
-        # input(f"Response: {response_text}")
+        response_text = response["choices"][0]["text"]
 
         usage = ModelUsage(**response["usage"])
-        return response_text, messages, usage
+        return response_text, usage
 
     @classmethod
     def register_arguments(cls, parser: ArgumentParser):
@@ -245,14 +232,14 @@ class Llama(LLMModel):
         group.add_argument(
             "--llama-model",
             type=str,
-            default="TheBloke/Llama-2-13B-chat-GGUF",
+            default="QuantFactory/Meta-Llama-3-8B-Instruct-GGUF",
             help="A pre-trained model from HF hub",
         ),
         group.add_argument(
             "--llama-model-file",
             type=str,
             # TODO provide some help for selecting model files, and point user to set this argument if needed
-            default="llama-2-13b-chat.Q5_K_M.gguf",
+            default="Meta-Llama-3-8B-Instruct.Q5_K_M.gguf",
            help="Specify the model file in case of a pre-trained model from HF hub, e.g., a specific quantized version",
         ),
         group.add_argument(
@@ -273,75 +260,88 @@ class Llama(LLMModel):
         )
 
 
-class OpenAI(LLMModel):
+class ChatModel:
+
+    def create_completion(
+        self,
+        system_message: str | None,
+        prompt: str,
+        *,
+        use_cache: bool = False,
+        prompt_appendix: str = "",
+        prompt_prefix: str = "",
prompt_suffix: str = "", + stop: str = None, + history: ChatMessages | None = None, + **kwargs: Any, + ) -> tuple[str, ModelUsage]: + # create prompt + prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix + messages = [self._get_user_message(prompt)] + + # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case + # TODO is it better to check for a system message in the history? + if history is None and system_message: + history = [self._get_system_message(system_message)] + + reponse, usage = self._create_completion( + messages=messages, + stop=stop, + use_cache=use_cache, + max_tokens=self.options.max_tokens, + **kwargs, + ) + + messages.append(self._get_assistant_message(reponse)) + return reponse, history + messages, usage + + +class LlamaChat(ChatModel, Llama): + + def _create_completion( + self, + use_cache: bool = False, + **kwargs, + ): + response = self._call_model( + self.model.create_chat_completion, + use_cache=use_cache, + **kwargs, + ) + response_text = response["choices"][0]["message"]["content"] + + usage = ModelUsage(**response["usage"]) + return response_text, usage + + +class OpenAiChat(ChatModel, LLMModel): """Queries an OpenAI model using its API.""" def __init__( self, options: Namespace, - verbose: bool = False, **kwargs, ) -> None: - self.model_name = options.openai_model - super().__init__(options) - # initialize client for API calls self.openai_client = openai.OpenAI(**kwargs) + self.model_name = options.openai_model + + super().__init__(options, **kwargs) - def __call__( + def _create_completion( self, - system_message: str | None, - prompt: str, - *, - prompt_appendix: str = "", - prompt_prefix: str = "", - prompt_suffix: str = "", - chat: bool | None = None, - stop: str = "</prompt>", - max_tokens: int = None, - **kwargs: Any, - ) -> tuple[str, ModelUsage]: - if chat is None: - chat = self.chat - - if chat: - messages = [ - { - "role": "user", - "content": prompt_prefix + prompt + prompt_suffix + prompt_appendix, - } - ] - if system_message: - messages.insert( - 0, - { - "role": "system", - "content": system_message, - }, - ) - response = self.openai_client.chat.completions.create( - model=self.model_name, - messages=messages, - stop=stop, - max_tokens=max_tokens, - **kwargs, - ) - usage = ModelUsage(**response.usage.__dict__) - return response.choices[0].message.content, usage - else: - response = self.openai_client.completions.create( - model=self.model, - prompt=(system_message if system_message else "") - + prompt_prefix - + prompt - + prompt_suffix - + prompt_appendix, - stop=stop, - max_tokens=max_tokens, - **kwargs, - ) - usage = ModelUsage(**response.usage.__dict__) - return response.choices[0].text, usage + use_cache: bool = False, + **kwargs, + ): + response = self._call_model( + self.openai_client.chat.completions.create, + model=self.model_name, + use_cache=use_cache, + **kwargs, + ) + response_text = response.choices[0].message.content + usage = ModelUsage(**response.usage.__dict__) + return response_text, usage @classmethod def register_arguments(cls, parser: ArgumentParser): @@ -357,4 +357,12 @@ argument_group.add_argument( choices=LLMModel.models.keys(), default="llama", ) -argument_group.add_argument("--chat", "-c", action="store_true") +argument_group.add_argument( + "--disable-cache", + action="store_true", +) +argument_group.add_argument( + "--max-tokens", + type=int, + help="Maximum number of tokens being generated from LLM. 
", +) diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py index e3ed91b99b4466801e424c082ab09b8202adac8b..c43330e5957ec33e391240dd1d0acec9dd3bf4f5 100644 --- a/evoprompt/optimization.py +++ b/evoprompt/optimization.py @@ -34,7 +34,7 @@ def paraphrase_prompts( if num_tries >= max_tries: break num_tries += 1 - paraphrase, _, usage = model( + paraphrase, _, usage = model.create_completion( system_message=PARAPHRASE_PROMPT, prompt=prompt, prompt_prefix=' Instruction: "', diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py index 4348cbb2cec3fa97c4f6e3077722ae22d2ad18e0..8b27328b2139fc50716f132f8b26acbabeed8139 100644 --- a/evoprompt/task/task.py +++ b/evoprompt/task/task.py @@ -312,12 +312,14 @@ class Task(metaclass=ABCMeta): def predict(self, prompt: str, datum: DatasetDatum) -> tuple[str, ModelUsage]: # run model for inference using grammar to constrain output # TODO grammar also depends on prompt and vice-versa -> what are good labels? - response, _, usage = self.model( + response, _, usage = self.model.create_completion( system_message=SYSTEM_MESSAGE, prompt=prompt, prompt_appendix=self._get_prompt_text_for_datum(datum), # grammar can be applied to constrain the model output grammar=self._get_grammar(datum) if self.use_grammar else None, + # we use cached completions to speed up the process although we loose the non-deterministic behavior of LMs, but we're ok with a single result + use_cache=True, ) if not self.use_grammar: diff --git a/evoprompt/utils.py b/evoprompt/utils.py index c6b7803bfe3e8e314f6178e1dbf5f95ee5ebb21a..7cb3e1b1ae1fd98ae9ecd002e640bdf71f44337f 100644 --- a/evoprompt/utils.py +++ b/evoprompt/utils.py @@ -75,7 +75,7 @@ def initialize_run_directory(model: Callable): if file_handler is not None: logger.removeHandler(file_handler) - response, _, _ = model(None, run_name_prompt) + response, _, _ = model.create_completion(None, run_name_prompt) run_name_match = re.search(r"^\w+$", response, re.MULTILINE) existing_run_names = os.listdir(RUNS_DIR) if RUNS_DIR.exists() else [] if run_name_match is None or run_name_match.group(0) in existing_run_names: diff --git a/main.py b/main.py index 90330dc9253298bf2f770da7851ccef6f82b76d0..e0b8faedc0485ebce7a65ac35969faa971580a74 100644 --- a/main.py +++ b/main.py @@ -6,7 +6,7 @@ from dotenv import load_dotenv from evoprompt.cli import argument_parser from evoprompt.evolution import get_optimizer_class -from evoprompt.models import Llama, LLMModel +from evoprompt.models import Llama, LlamaChat, LLMModel from evoprompt.task import get_task from evoprompt.utils import init_rng, setup_console_logger @@ -61,8 +61,8 @@ if __name__ == "__main__": if debug: logger.info("DEBUG mode: Do a quick run") - # set up evolution model - evolution_model = LLMModel.get_model(options.evolution_engine, options) + # # set up evolution model + evolution_model = LLMModel.get_model(options.evolution_engine, options=options) match options.evolution_engine: case "llama": @@ -76,10 +76,12 @@ if __name__ == "__main__": logger.info("Using Llama as the evaluation engine") evaluation_model: LLMModel match options.evolution_engine: - case "llama": + case "llama" | "llamachat": evaluation_model = evolution_model case "openai": evaluation_model = Llama(options) + case "openaichat": + evaluation_model = LlamaChat(options) task = get_task(options.task, evaluation_model, **options.__dict__) logger.info(f"Running with task {task.__class__.__name__}")