from abc import ABC, abstractmethod
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Any

import openai
from llama_cpp import Llama

from cli import argument_parser
from opt_types import ModelUsage

current_directory = Path(__file__).resolve().parent


class LLMModel(ABC):
    """Common interface for local and API-backed language models."""

    chat: bool

    def __init__(self, options: Namespace):
        self.usage = ModelUsage()
        self.chat = options.chat

    @abstractmethod
    def __call__(
        self,
        system_message: str | None,
        prompt: str,
        *,
        prompt_appendix: str,
        prompt_prefix: str,
        prompt_suffix: str,
        chat: bool | None,
        stop: str,
        max_tokens: int,
        **kwargs: Any,
    ) -> Any:
        pass

    @classmethod
    @abstractmethod
    def register_arguments(cls, parser: ArgumentParser):
        pass


class Llama2(LLMModel):
    """Loads and queries a Llama2 model."""

    def __init__(
        self,
        options: Namespace,
        n_gpu_layers: int = 60,
        n_threads: int = 8,
        n_ctx: int = 4096,
        verbose: bool = False,
        **kwargs,
    ) -> None:
        # initialize model
        self.model = Llama(
            model_path=options.llama_path,
            chat_format="llama-2",
            verbose=verbose,
            n_gpu_layers=n_gpu_layers,
            n_threads=n_threads,
            n_ctx=n_ctx,
            **kwargs,
        )
        super().__init__(options)

    def __call__(
        self,
        system_message: str | None,
        prompt: str,
        *,
        prompt_appendix: str = "",
        prompt_prefix: str = "",
        prompt_suffix: str = "",
        chat: bool | None = None,
        stop: str = "</prompt>",
        max_tokens: int = 200,
        **kwargs: Any,
    ) -> tuple[str, ModelUsage]:
        # fall back to the mode selected on the command line
        if chat is None:
            chat = self.chat

        if chat:
            messages = [
                {
                    "role": "user",
                    "content": prompt_prefix + prompt + prompt_suffix + prompt_appendix,
                }
            ]
            if system_message:
                messages.insert(
                    0,
                    {
                        "role": "system",
                        "content": system_message,
                    },
                )
            response = self.model.create_chat_completion(
                messages=messages,
                stop=stop,
                max_tokens=max_tokens,
                **kwargs,
            )
            response_text = response["choices"][0]["message"]["content"]
        else:
            # plain completion: fold the system message into the prompt
            response = self.model.create_completion(
                prompt=(system_message if system_message else "")
                + prompt_prefix
                + prompt
                + prompt_suffix
                + prompt_appendix,
                stop=stop,
                max_tokens=max_tokens,
                **kwargs,
            )
            response_text = response["choices"][0]["text"]

        usage = ModelUsage(**response["usage"])
        self.usage += usage
        return response_text, usage

    @classmethod
    def register_arguments(cls, parser: ArgumentParser):
        group = parser.add_argument_group("Llama2 model arguments")
        group.add_argument(
            "--llama-path", default="models/llama-2-13b-chat.Q5_K_M.gguf"
        )


class OpenAI(LLMModel):
    """Queries an OpenAI model using its API."""

    def __init__(
        self,
        options: Namespace,
        verbose: bool = False,
        **kwargs,
    ) -> None:
        self.model_name = options.openai_model
        super().__init__(options)

        # initialize client for API calls
        self.openai_client = openai.OpenAI(**kwargs)

    def __call__(
        self,
        system_message: str | None,
        prompt: str,
        *,
        prompt_appendix: str = "",
        prompt_prefix: str = "",
        prompt_suffix: str = "",
        chat: bool | None = None,
        stop: str = "</prompt>",
        max_tokens: int = 200,
        **kwargs: Any,
    ) -> tuple[str, ModelUsage]:
        # fall back to the mode selected on the command line
        if chat is None:
            chat = self.chat

        if chat:
            messages = [
                {
                    "role": "user",
                    "content": prompt_prefix + prompt + prompt_suffix + prompt_appendix,
                }
            ]
            if system_message:
                messages.insert(
                    0,
                    {
                        "role": "system",
                        "content": system_message,
                    },
                )
            response = self.openai_client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                stop=stop,
                max_tokens=max_tokens,
                **kwargs,
            )
            usage = ModelUsage(**response.usage.__dict__)
            self.usage += usage
            return response.choices[0].message.content, usage
        else:
            response = self.openai_client.completions.create(
                model=self.model_name,
                prompt=(system_message if system_message else "")
                + prompt_prefix
                + prompt
                + prompt_suffix
                + prompt_appendix,
                stop=stop,
                max_tokens=max_tokens,
                **kwargs,
            )
            usage = ModelUsage(**response.usage.__dict__)
            self.usage += usage
            return response.choices[0].text, usage

    @classmethod
    def register_arguments(cls, parser: ArgumentParser):
        group = parser.add_argument_group("OpenAI model arguments")
        group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")


# discover all concrete model classes and register their CLI arguments
models = {model.__name__.lower(): model for model in LLMModel.__subclasses__()}

for name, model in models.items():
    model.register_arguments(argument_parser)


def get_model(name: str, options: Namespace):
    if name not in models:
        raise ValueError(f"Model {name} does not exist")
    return models[name](options)


argument_group = argument_parser.add_argument_group("Model arguments")
argument_group.add_argument(
    "--evolution-engine", "-e", type=str, choices=models.keys(), default="llama2"
)
argument_group.add_argument("--chat", "-c", action="store_true")
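

# A minimal usage sketch, not part of the original module: it assumes
# `argument_parser` from cli is a plain argparse.ArgumentParser and that this
# file is run directly; the system message and prompt below are placeholders.
if __name__ == "__main__":
    options = argument_parser.parse_args()
    model = get_model(options.evolution_engine, options)
    response_text, usage = model(
        "You are a helpful assistant.",  # system_message
        "Write a one-line greeting.",  # prompt (hypothetical example input)
    )
    print(response_text)
    print(usage)  # per-call token usage; model.usage accumulates across calls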