From 2df18f7fd25dc3a31e7479ceb1febd4c899f498c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Fri, 16 Aug 2024 17:39:53 +0200
Subject: [PATCH] remove is_chat argument

---
 evoprompt/models.py | 158 ++++++++++++++++++++++++++------------------
 main.py             |  14 +++-
 2 files changed, 106 insertions(+), 66 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 3131fff..c646803 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -1,11 +1,11 @@
 import functools
 import inspect
 import logging
+import warnings
 from abc import ABC, abstractmethod
 from argparse import ArgumentParser, Namespace
 from pathlib import Path
 from typing import Any, Callable, ClassVar
-import warnings
 
 import llama_cpp
 import openai
@@ -22,7 +22,6 @@ warnings.simplefilter("once")
 
 class LLMModel(ABC):
     models: ClassVar[dict[str, type["LLMModel"]]] = {}
-    chat: bool
 
     def __init_subclass__(cls) -> None:
         if inspect.isabstract(cls):
@@ -43,7 +42,6 @@ class LLMModel(ABC):
 
     def __init__(self, options: Namespace, **kwargs):
         self.usage = ModelUsage()
-        self.chat = options.chat
 
         # store kwargs for caching
         self.options = options
@@ -56,6 +54,16 @@ class LLMModel(ABC):
             self._call_model_cached
         )
 
+    @abstractmethod
+    def build_model_input(
+        self,
+        prompt: str,
+        system_message: str | None,
+        messages: list[dict[str, str]],
+        history: list[dict[str, str]] | None = None,
+    ):
+        pass
+
     def create_completion(
         self,
         system_message: str | None,
@@ -65,42 +73,18 @@ class LLMModel(ABC):
         prompt_appendix: str = "",
         prompt_prefix: str = "",
         prompt_suffix: str = "",
-        chat: bool | None = None,
         stop: str = None,
-        history: dict = None,
+        history: list[dict[str, str]] | None = None,
         **kwargs: Any,
-    ) -> tuple[str, ModelUsage]:
-        if chat is None:
-            chat = self.chat
-        max_tokens = kwargs.pop("max_tokens", self.options.max_tokens)
-
+    ) -> tuple[str, list[dict[str, str]], ModelUsage]:
         # create prompt
         prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
-
-        if not chat and system_message:
-            prompt = system_message + prompt
-
         messages = [self._get_user_message(prompt)]
-
-        if chat:
-            # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
-            # TODO is it better to check for a system message in the history?
-            if history is not None:
-                messages = history + messages
-            elif system_message:
-                messages.insert(
-                    0,
-                    self._get_system_message(system_message),
-                )
-            model_input = {"messages": messages}
-        else:
-            model_input = {"prompt": prompt}
+        model_input = self.build_model_input(prompt, system_message, messages, history)
 
         reponse, usage = self._create_completion(
-            chat=chat,
             **model_input,
             stop=stop,
-            max_tokens=max_tokens,
             use_cache=use_cache,
             **kwargs,
         )
@@ -137,7 +121,7 @@ class LLMModel(ABC):
         if use_cache:
             # use cached function call
             cache_key = self._compute_cache_key(
-                model_completion_fn.__name__, **self.options.__dict__, **self.kwargs
+                self.__class__.__name__, **self.options.__dict__, **self.kwargs
             )
             return self._call_model_cached(model_completion_fn, cache_key, **kwargs)
         else:
@@ -205,26 +189,29 @@ class Llama(LLMModel):
         # needs to be called after model is initialized
         super().__init__(options=options, n_ctx=n_ctx, **kwargs)
 
+    def build_model_input(
+        self,
+        prompt: str,
+        system_message: str | None,
+        messages: list[dict[str, str]],
+        history: list[dict[str, str]] | None = None,
+    ):
+
+        if system_message is not None:
+            prompt = system_message + prompt
+        return {"prompt": prompt}
+
     def _create_completion(
         self,
-        chat: bool,
         use_cache: bool = False,
         **kwargs,
     ):
-        if chat:
-            response = self._call_model(
-                self.model.create_chat_completion,
-                use_cache=use_cache,
-                **kwargs,
-            )
-            response_text = response["choices"][0]["message"]["content"]
-        else:
-            response = self._call_model(
-                self.model.create_completion,
-                use_cache=use_cache,
-                **kwargs,
-            )
-            response_text = response["choices"][0]["text"]
+        response = self._call_model(
+            self.model.create_completion,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        response_text = response["choices"][0]["text"]
 
         usage = ModelUsage(**response["usage"])
         return response_text, usage
@@ -272,6 +259,43 @@ class Llama(LLMModel):
         )
 
 
+class LlamaChat(Llama):
+
+    def _create_completion(
+        self,
+        use_cache: bool = False,
+        **kwargs,
+    ):
+        response = self._call_model(
+            self.model.create_chat_completion,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        response_text = response["choices"][0]["message"]["content"]
+
+        usage = ModelUsage(**response["usage"])
+        return response_text, usage
+
+    def build_model_input(
+        self,
+        prompt: str,
+        system_message: str | None,
+        messages: list[dict[str, str]],
+        history: list[dict[str, str]] | None = None,
+    ):
+
+        # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
+        # TODO is it better to check for a system message in the history?
+        if history is not None:
+            [messages.insert(index, entry) for index, entry in enumerate(history)]
+        elif system_message:
+            messages.insert(
+                0,
+                self._get_system_message(system_message),
+            )
+        return {"messages": messages}
+
+
 class OpenAI(LLMModel):
     """Queries an OpenAI model using its API."""
 
@@ -288,26 +312,16 @@ class OpenAI(LLMModel):
 
     def _create_completion(
        self,
-        chat: bool,
         use_cache: bool = False,
         **kwargs,
     ):
-        if chat:
-            response = self._call_model(
-                self.openai_client.chat.completions.create,
-                model=self.model_name,
-                use_cache=use_cache,
-                **kwargs,
-            )
-            response_text = response.choices[0].message.content
-        else:
-            response = self._call_model(
-                self.openai_client.completions.create,
-                model=self.model,
-                use_cache=use_cache,
-                **kwargs,
-            )
-            response_text = response.choices[0].text
+        response = self._call_model(
+            self.openai_client.completions.create,
+            model=self.model,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        response_text = response.choices[0].text
 
         usage = ModelUsage(**response.usage.__dict__)
         return response_text, usage
@@ -322,6 +336,24 @@ class OpenAI(LLMModel):
         group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
 
 
+class OpenAiChat(OpenAI):
+
+    def _create_completion(
+        self,
+        use_cache: bool = False,
+        **kwargs,
+    ):
+        response = self._call_model(
+            self.openai_client.chat.completions.create,
+            model=self.model_name,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        response_text = response.choices[0].message.content
+        usage = ModelUsage(**response.usage.__dict__)
+        return response_text, usage
+
+
 argument_group = argument_parser.add_argument_group("Model arguments")
 argument_group.add_argument(
     "--evolution-engine",
diff --git a/main.py b/main.py
index 3ed6132..90bac4b 100644
--- a/main.py
+++ b/main.py
@@ -6,7 +6,7 @@ from dotenv import load_dotenv
 
 from evoprompt.cli import argument_parser
 from evoprompt.evolution import get_optimizer_class
-from evoprompt.models import Llama, LLMModel
+from evoprompt.models import Llama, LlamaChat, LLMModel
 from evoprompt.task import get_task
 from evoprompt.utils import init_rng, setup_console_logger
 
@@ -62,7 +62,12 @@ if __name__ == "__main__":
         logger.info("DEBUG mode: Do a quick run")
 
     # set up evolution model
-    evolution_model = LLMModel.get_model(options.evolution_engine, options=options)
+    evolution_model_name = (
+        (options.evolution_engine + "chat")
+        if options.chat
+        else options.evolution_engine
+    )
+    evolution_model = LLMModel.get_model(evolution_model_name, options=options)
 
     match options.evolution_engine:
         case "llama":
@@ -79,7 +84,10 @@ if __name__ == "__main__":
         case "llama":
             evaluation_model = evolution_model
         case "openai":
-            evaluation_model = Llama(options)
+            if not options.chat:
+                evaluation_model = Llama(options)
+            else:
+                evaluation_model = LlamaChat(options)
 
     task = get_task(options.task, evaluation_model, **options.__dict__)
     logger.info(f"Running with task {task.__class__.__name__}")
-- 
GitLab
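
For reference, a minimal usage sketch (not part of the patch) of how the chat variants are selected after this change. It assumes that LLMModel.__init_subclass__ registers each concrete subclass under its lower-cased class name ("llama", "llamachat", "openai", "openaichat"), which is what the options.evolution_engine + "chat" lookup in main.py relies on; the Namespace below is illustrative and omits the model-specific options (model path, context size, etc.) a real run would need.

    from argparse import Namespace

    from evoprompt.models import LLMModel

    # illustrative options only; in main.py these come from the CLI argument parser
    options = Namespace(evolution_engine="llama", chat=True)

    # pick the registry key: "llama" -> Llama, "llamachat" -> LlamaChat, ...
    model_name = (
        (options.evolution_engine + "chat") if options.chat else options.evolution_engine
    )
    evolution_model = LLMModel.get_model(model_name, options=options)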