Commit 2df18f7f authored by Grießhaber Daniel

remove is_chat argument

parent 691ced52
2 merge requests: !2 remove is_chat argument, !1 Refactor models
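The chat behaviour now lives in dedicated subclasses (LlamaChat, OpenAiChat) instead of a chat/is_chat flag threaded through create_completion. A minimal usage sketch, assuming the models registry populated in __init_subclass__ keys each subclass by its lowercased class name (which is what the updated main script below relies on):

    # pick the chat variant by registry name instead of passing is_chat
    name = options.evolution_engine + "chat" if options.chat else options.evolution_engine
    model = LLMModel.get_model(name, options=options)  # e.g. "llamachat" -> LlamaChat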
import functools
import inspect
import logging
import warnings
from abc import ABC, abstractmethod
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Any, Callable, ClassVar
import warnings
import llama_cpp
import openai
@@ -22,7 +22,6 @@ warnings.simplefilter("once")
class LLMModel(ABC):
models: ClassVar[dict[str, type["LLMModel"]]] = {}
chat: bool
def __init_subclass__(cls) -> None:
if inspect.isabstract(cls):
@@ -43,7 +42,6 @@ class LLMModel(ABC):
def __init__(self, options: Namespace, **kwargs):
self.usage = ModelUsage()
self.chat = options.chat
# store kwargs for caching
self.options = options
@@ -56,6 +54,16 @@
self._call_model_cached
)
@abstractmethod
def build_model_input(
self,
prompt: str,
system_message: str | None,
messages: list[dict[str, str]],
history: list[dict[str, str]] | None = None,
):
pass
def create_completion(
self,
system_message: str | None,
@@ -65,42 +73,18 @@
prompt_appendix: str = "",
prompt_prefix: str = "",
prompt_suffix: str = "",
chat: bool | None = None,
stop: str = None,
history: dict = None,
history: list[dict[str, str]] | None = None,
**kwargs: Any,
) -> tuple[str, ModelUsage]:
if chat is None:
chat = self.chat
max_tokens = kwargs.pop("max_tokens", self.options.max_tokens)
) -> tuple[str, list[dict[str, str]], ModelUsage]:
# create prompt
prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
if not chat and system_message:
prompt = system_message + prompt
messages = [self._get_user_message(prompt)]
if chat:
# a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
# TODO is it better to check for a system message in the history?
if history is not None:
messages = history + messages
elif system_message:
messages.insert(
0,
self._get_system_message(system_message),
)
model_input = {"messages": messages}
else:
model_input = {"prompt": prompt}
model_input = self.build_model_input(prompt, system_message, messages, history)
response, usage = self._create_completion(
chat=chat,
**model_input,
stop=stop,
max_tokens=max_tokens,
use_cache=use_cache,
**kwargs,
)
@@ -137,7 +121,7 @@
if use_cache:
# use cached function call
cache_key = self._compute_cache_key(
model_completion_fn.__name__, **self.options.__dict__, **self.kwargs
self.__class__.__name__, **self.options.__dict__, **self.kwargs
)
return self._call_model_cached(model_completion_fn, cache_key, **kwargs)
else:
@@ -205,26 +189,29 @@ class Llama(LLMModel):
# needs to be called after model is initialized
super().__init__(options=options, n_ctx=n_ctx, **kwargs)
def build_model_input(
self,
prompt: str,
system_message: str | None,
messages: list[dict[str, str]],
history: list[dict[str, str]] | None = None,
):
if system_message is not None:
prompt = system_message + prompt
return {"prompt": prompt}
def _create_completion(
self,
chat: bool,
use_cache: bool = False,
**kwargs,
):
if chat:
response = self._call_model(
self.model.create_chat_completion,
use_cache=use_cache,
**kwargs,
)
response_text = response["choices"][0]["message"]["content"]
else:
response = self._call_model(
self.model.create_completion,
use_cache=use_cache,
**kwargs,
)
response_text = response["choices"][0]["text"]
response = self._call_model(
self.model.create_completion,
use_cache=use_cache,
**kwargs,
)
response_text = response["choices"][0]["text"]
usage = ModelUsage(**response["usage"])
return response_text, usage
@@ -272,6 +259,43 @@ class Llama(LLMModel):
)
class LlamaChat(Llama):
def _create_completion(
self,
use_cache: bool = False,
**kwargs,
):
response = self._call_model(
self.model.create_chat_completion,
use_cache=use_cache,
**kwargs,
)
response_text = response["choices"][0]["message"]["content"]
usage = ModelUsage(**response["usage"])
return response_text, usage
def build_model_input(
self,
prompt: str,
system_message: str | None,
messages: list[dict[str, str]],
history: list[dict[str, str]] | None = None,
):
# a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
# TODO is it better to check for a system message in the history?
if history is not None:
messages[:0] = history  # prepend the history in place
elif system_message:
messages.insert(
0,
self._get_system_message(system_message),
)
return {"messages": messages}
class OpenAI(LLMModel):
"""Queries an OpenAI model using its API."""
@@ -288,26 +312,16 @@ class OpenAI(LLMModel):
def _create_completion(
self,
chat: bool,
use_cache: bool = False,
**kwargs,
):
if chat:
response = self._call_model(
self.openai_client.chat.completions.create,
model=self.model_name,
use_cache=use_cache,
**kwargs,
)
response_text = response.choices[0].message.content
else:
response = self._call_model(
self.openai_client.completions.create,
model=self.model,
use_cache=use_cache,
**kwargs,
)
response_text = response.choices[0].text
response = self._call_model(
self.openai_client.completions.create,
model=self.model,
use_cache=use_cache,
**kwargs,
)
response_text = response.choices[0].text
usage = ModelUsage(**response.usage.__dict__)
return response_text, usage
@@ -322,6 +336,24 @@ class OpenAI(LLMModel):
group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
class OpenAiChat(OpenAI):
def _create_completion(
self,
use_cache: bool = False,
**kwargs,
):
response = self._call_model(
self.openai_client.chat.completions.create,
model=self.model_name,
use_cache=use_cache,
**kwargs,
)
response_text = response.choices[0].message.content
usage = ModelUsage(**response.usage.__dict__)
return response_text, usage
argument_group = argument_parser.add_argument_group("Model arguments")
argument_group.add_argument(
"--evolution-engine",
@@ -6,7 +6,7 @@ from dotenv import load_dotenv
from evoprompt.cli import argument_parser
from evoprompt.evolution import get_optimizer_class
from evoprompt.models import Llama, LLMModel
from evoprompt.models import Llama, LlamaChat, LLMModel
from evoprompt.task import get_task
from evoprompt.utils import init_rng, setup_console_logger
@@ -62,7 +62,12 @@ if __name__ == "__main__":
logger.info("DEBUG mode: Do a quick run")
# set up evolution model
evolution_model = LLMModel.get_model(options.evolution_engine, options=options)
evolution_model_name = (
(options.evolution_engine + "chat")
if options.chat
else options.evolution_engine
)
evolution_model = LLMModel.get_model(evolution_model_name, options=options)
match options.evolution_engine:
case "llama":
@@ -79,7 +84,10 @@
case "llama":
evaluation_model = evolution_model
case "openai":
evaluation_model = Llama(options)
if not options.chat:
evaluation_model = Llama(options)
else:
evaluation_model = LlamaChat(options)
task = get_task(options.task, evaluation_model, **options.__dict__)
logger.info(f"Running with task {task.__class__.__name__}")