Commit 7b9808b6 authored by Grießhaber Daniel

Merge branch 'remove-is-chat' into 'refactor-models'

remove is_chat argument

See merge request !2
parents 691ced52 23d528a7
Related merge requests: !2 remove is_chat argument, !1 Refactor models
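
For orientation, a minimal sketch of how callers are affected by this merge: the --chat/-c flag and the chat= keyword on create_completion are removed, and chat behaviour is instead selected by instantiating a chat-specific model class (LlamaChat, OpenAiChat). The engine name "llamachat" and the option plumbing are taken from the diff below; the driver structure and prompts are illustrative only.

# Hypothetical driver snippet (not part of this commit); assumes the evoprompt
# package from this repository and a model configured via its usual CLI options.
from evoprompt.cli import argument_parser
from evoprompt.models import LLMModel

options = argument_parser.parse_args()

# Before this merge: LLMModel.get_model("llama", options=options) plus the
# --chat flag and a chat= keyword on create_completion toggled chat behaviour.
# After this merge: pick a chat-capable subclass by name instead.
model = LLMModel.get_model("llamachat", options=options)
response, messages, usage = model.create_completion(
    system_message="You are a helpful assistant.",
    prompt="Summarize the change introduced by this merge request.",
)
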
import functools
import hashlib
import inspect
import json
import logging
import warnings
from abc import ABC, abstractmethod
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Any, Callable, ClassVar
import warnings
import llama_cpp
import openai
@@ -19,10 +20,11 @@ logger = logging.getLogger(__name__)
logging.captureWarnings(True)
warnings.simplefilter("once")
ChatMessages = list[dict[str, str]]
class LLMModel(ABC):
models: ClassVar[dict[str, type["LLMModel"]]] = {}
chat: bool
def __init_subclass__(cls) -> None:
if inspect.isabstract(cls):
@@ -36,26 +38,25 @@ class LLMModel(ABC):
raise ValueError("Model %s does not exist", name)
return cls.models[name](options=options, **kwargs)
@functools.lru_cache
def _compute_cache_key(self, name, **kwargs):
# we use a tuple of the model name, the options, and the kwargs as the cache key
return (name,) + tuple((key, value) for key, value in kwargs.items())
def __init__(self, options: Namespace, **kwargs):
self.usage = ModelUsage()
self.chat = options.chat
# store kwargs for caching
self.options = options
self.kwargs = kwargs
# set up caching for model calls
self._call_model_cached = None
if not options.disable_cache:
cache = Cache(Path(".cache_dir", self.model_cache_key))
self._call_model_cached = cache.memoize(typed=True, ignore=[0, "func"])(
self._call_model_cached
)
@cache.memoize(typed=True, ignore=[0, "func"])
def _call_function(func, *args, **kwargs):
return func(*args, **kwargs)
self._call_model_cached = _call_function
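
The caching setup above uses diskcache's Cache.memoize with ignore=[0, "func"], so the wrapped model function itself is not part of the cache key. A minimal, self-contained sketch of that mechanism, assuming the diskcache package; the function names and arguments are illustrative only:

# Sketch only: demonstrates memoize(ignore=...) as used above, outside the class.
from pathlib import Path
from diskcache import Cache

cache = Cache(Path(".cache_dir", "example_model"))

@cache.memoize(typed=True, ignore=[0, "func"])
def call_cached(func, *args, **kwargs):
    # Positional index 0 (and the keyword "func") are excluded from the cache
    # key, so the possibly unpicklable bound model method never needs to be
    # serialized; only the remaining arguments identify a cache entry.
    return func(*args, **kwargs)

def slow_completion(prompt, max_tokens=32):
    # Stand-in for an expensive model call.
    return f"completion for {prompt!r} ({max_tokens} tokens)"

first = call_cached(slow_completion, prompt="Hello", max_tokens=32)
second = call_cached(slow_completion, prompt="Hello", max_tokens=32)  # served from the on-disk cache
assert first == second
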
@abstractmethod
def create_completion(
self,
system_message: str | None,
@@ -65,48 +66,10 @@ class LLMModel(ABC):
prompt_appendix: str = "",
prompt_prefix: str = "",
prompt_suffix: str = "",
chat: bool | None = None,
stop: str = None,
history: dict = None,
history: ChatMessages | None = None,
**kwargs: Any,
) -> tuple[str, ModelUsage]:
if chat is None:
chat = self.chat
max_tokens = kwargs.pop("max_tokens", self.options.max_tokens)
# create prompt
prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
if not chat and system_message:
prompt = system_message + prompt
messages = [self._get_user_message(prompt)]
if chat:
# a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
# TODO is it better to check for a system message in the history?
if history is not None:
messages = history + messages
elif system_message:
messages.insert(
0,
self._get_system_message(system_message),
)
model_input = {"messages": messages}
else:
model_input = {"prompt": prompt}
reponse, usage = self._create_completion(
chat=chat,
**model_input,
stop=stop,
max_tokens=max_tokens,
use_cache=use_cache,
**kwargs,
)
messages.append(self._get_assistant_message(reponse))
return reponse, messages, usage
) -> tuple[str, ModelUsage]: ...
def _get_user_message(self, content: str):
return {
@@ -134,18 +97,26 @@ class LLMModel(ABC):
warnings.warn("Caching is disabled when a grammar is provided.")
use_cache = False
if use_cache:
# use cached function call
cache_key = self._compute_cache_key(
model_completion_fn.__name__, **self.options.__dict__, **self.kwargs
)
return self._call_model_cached(model_completion_fn, cache_key, **kwargs)
if use_cache and self._call_model_cached is not None:
return self._call_model_cached(model_completion_fn, **kwargs)
else:
return model_completion_fn(**kwargs)
def _call_model_cached(self, func, cache_key, *args, **kwargs):
# `cache_key` is added to the cache key (e.g., to distinguish between different models), but it is not used in the function
return func(*args, **kwargs)
@property
def model_cache_key(self):
unique_options_key = json.dumps(
vars(self.options),
sort_keys=True,
) + json.dumps(
self.kwargs,
sort_keys=True,
)
cache_key = (
str(self.model_name).replace("/", "_")
+ "/"
+ hashlib.sha1(unique_options_key.encode()).hexdigest()
)
return cache_key
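
A small standalone illustration of the directory layout that model_cache_key produces; the option values and the model name are made up for the example:

import hashlib
import json
from pathlib import Path

options = {"max_tokens": 100, "disable_cache": False}  # illustrative options only
kwargs = {"n_ctx": 4096}
model_name = "TheBloke/Llama-2-7B-GGUF"  # hypothetical model name

unique_options_key = json.dumps(options, sort_keys=True) + json.dumps(kwargs, sort_keys=True)
cache_key = model_name.replace("/", "_") + "/" + hashlib.sha1(unique_options_key.encode()).hexdigest()
cache_dir = Path(".cache_dir", cache_key)
# -> .cache_dir/TheBloke_Llama-2-7B-GGUF/<40-character sha1 hex digest>
# Every distinct combination of options and constructor kwargs gets its own cache directory.
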
@classmethod
@abstractmethod
@@ -205,34 +176,51 @@ class Llama(LLMModel):
# needs to be called after model is initialized
super().__init__(options=options, n_ctx=n_ctx, **kwargs)
def create_completion(
self,
system_message: str | None,
prompt: str,
*,
use_cache: bool = False,
prompt_appendix: str = "",
prompt_prefix: str = "",
prompt_suffix: str = "",
stop: str = None,
history: ChatMessages | None = None,
**kwargs: Any,
) -> tuple[str, ModelUsage]:
# create prompt
prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
messages = [self._get_user_message(prompt)]
if system_message is not None:
prompt = system_message + prompt
reponse, usage = self._create_completion(
prompt=prompt,
stop=stop,
use_cache=use_cache,
max_tokens=self.options.max_tokens,
**kwargs,
)
messages.append(self._get_assistant_message(reponse))
return reponse, messages, usage
def _create_completion(
self,
chat: bool,
use_cache: bool = False,
**kwargs,
):
if chat:
response = self._call_model(
self.model.create_chat_completion,
use_cache=use_cache,
**kwargs,
)
response_text = response["choices"][0]["message"]["content"]
else:
response = self._call_model(
self.model.create_completion,
use_cache=use_cache,
**kwargs,
)
response_text = response["choices"][0]["text"]
response = self._call_model(
self.model.create_completion,
use_cache=use_cache,
**kwargs,
)
response_text = response["choices"][0]["text"]
usage = ModelUsage(**response["usage"])
return response_text, usage
@property
def model_cache_key(self):
return self.model_name
@classmethod
def register_arguments(cls, parser: ArgumentParser):
group = parser.add_argument_group(f"{cls.__name__} model arguments")
@@ -272,7 +260,61 @@ class Llama(LLMModel):
)
class OpenAI(LLMModel):
class ChatModel:
def create_completion(
self,
system_message: str | None,
prompt: str,
*,
use_cache: bool = False,
prompt_appendix: str = "",
prompt_prefix: str = "",
prompt_suffix: str = "",
stop: str = None,
history: ChatMessages | None = None,
**kwargs: Any,
) -> tuple[str, ModelUsage]:
# create prompt
prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
messages = [self._get_user_message(prompt)]
# a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
# TODO is it better to check for a system message in the history?
if history is None and system_message:
history = [self._get_system_message(system_message)]
reponse, usage = self._create_completion(
messages=messages,
stop=stop,
use_cache=use_cache,
max_tokens=self.options.max_tokens,
**kwargs,
)
messages.append(self._get_assistant_message(reponse))
return reponse, history + messages, usage
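
To make the history contract above concrete, a sketch of two consecutive calls against a chat model (e.g. a LlamaChat or OpenAiChat instance named model); the prompts are illustrative and the role/content message shape is the usual chat format assumed by the underlying APIs:

# First turn: no history is passed, so the system message is wrapped into a
# one-element history and ends up at the front of the returned message list.
response, messages, usage = model.create_completion(
    system_message="You are a terse assistant.",
    prompt="What does this merge change?",
)
# messages is now (assumed shape):
# [
#     {"role": "system", "content": "You are a terse assistant."},
#     {"role": "user", "content": "What does this merge change?"},
#     {"role": "assistant", "content": response},
# ]

# Follow-up turn: the previous messages are passed back as history; in that
# case no system message is inserted, since the history is assumed to contain one.
response, messages, usage = model.create_completion(
    system_message=None,
    prompt="And why was the chat argument removed?",
    history=messages,
)
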
class LlamaChat(ChatModel, Llama):
def _create_completion(
self,
use_cache: bool = False,
**kwargs,
):
response = self._call_model(
self.model.create_chat_completion,
use_cache=use_cache,
**kwargs,
)
response_text = response["choices"][0]["message"]["content"]
usage = ModelUsage(**response["usage"])
return response_text, usage
class OpenAiChat(ChatModel, LLMModel):
"""Queries an OpenAI model using its API."""
def __init__(
@@ -288,34 +330,19 @@ class OpenAI(LLMModel):
def _create_completion(
self,
chat: bool,
use_cache: bool = False,
**kwargs,
):
if chat:
response = self._call_model(
self.openai_client.chat.completions.create,
model=self.model_name,
use_cache=use_cache,
**kwargs,
)
response_text = response.choices[0].message.content
else:
response = self._call_model(
self.openai_client.completions.create,
model=self.model,
use_cache=use_cache,
**kwargs,
)
response_text = response.choices[0].text
response = self._call_model(
self.openai_client.chat.completions.create,
model=self.model_name,
use_cache=use_cache,
**kwargs,
)
response_text = response.choices[0].message.content
usage = ModelUsage(**response.usage.__dict__)
return response_text, usage
@property
def model_cache_key(self):
return self.model_name
@classmethod
def register_arguments(cls, parser: ArgumentParser):
group = parser.add_argument_group("OpenAI model arguments")
@@ -339,4 +366,3 @@ argument_group.add_argument(
type=int,
help="Maximum number of tokens being generated from LLM. ",
)
argument_group.add_argument("--chat", "-c", action="store_true")
@@ -6,7 +6,7 @@ from dotenv import load_dotenv
from evoprompt.cli import argument_parser
from evoprompt.evolution import get_optimizer_class
from evoprompt.models import Llama, LLMModel
from evoprompt.models import Llama, LlamaChat, LLMModel
from evoprompt.task import get_task
from evoprompt.utils import init_rng, setup_console_logger
@@ -61,7 +61,7 @@ if __name__ == "__main__":
if debug:
logger.info("DEBUG mode: Do a quick run")
# set up evolution model
# # set up evolution model
evolution_model = LLMModel.get_model(options.evolution_engine, options=options)
match options.evolution_engine:
@@ -76,10 +76,12 @@ if __name__ == "__main__":
logger.info("Using Llama as the evaluation engine")
evaluation_model: LLMModel
match options.evolution_engine:
case "llama":
case "llama" | "llamachat":
evaluation_model = evolution_model
case "openai":
evaluation_model = Llama(options)
case "openaichat":
evaluation_model = LlamaChat(options)
task = get_task(options.task, evaluation_model, **options.__dict__)
logger.info(f"Running with task {task.__class__.__name__}")