From 356db033d9c4938ca30193b7d5deaf5a32de2f67 Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Tue, 19 Mar 2024 12:31:32 +0100
Subject: [PATCH] Add system message for OpenAI model

---
 main.py   |  13 ++++---
 models.py | 112 ++++++++++++++++++++++++++++++++++++++----------------
 task.py   |  43 +++++++++------------
 utils.py  |   3 +-
 4 files changed, 107 insertions(+), 64 deletions(-)

diff --git a/main.py b/main.py
index 66676ca..5524016 100644
--- a/main.py
+++ b/main.py
@@ -36,8 +36,8 @@ def paraphrase_prompts(prompt: str, n: int):
         paraphrase, usage = evolution_model(
             system_message=PARAPHRASE_PROMPT,
             prompt=prompt,
-            prompt_prefix=" Instruction: \"",
-            prompt_suffix="\"",
+            prompt_prefix=' Instruction: "',
+            prompt_suffix='"',
         )
         total_usage += usage
         if "<prompt>" in paraphrase:
@@ -64,7 +64,10 @@ def selection(prompts):
     selection_probabilities = [score / sum(scores) for score in scores]
     return choice(prompts, size=2, replace=False, p=selection_probabilities)
 
-SYSTEM_MESSAGE = "Please follow the instruction step-by-step to generate a better prompt."
+
+SYSTEM_MESSAGE = (
+    "Please follow the instruction step-by-step to generate a better prompt."
+)
 
 GA_PROMPT = """
 1. Cross over the following prompts and generate a new prompt:
@@ -98,7 +101,7 @@ def evolution_ga(prompt1: str, prompt2: str):
     # generate a new prompt based on these steps to perform Evo(·) in Algorithm 1.
     evolved_prompt, usage = evolution_model(
         system_message=SYSTEM_MESSAGE,
-        prompt=GA_PROMPT.format(prompt1=prompt1, prompt2=prompt2)
+        prompt=GA_PROMPT.format(prompt1=prompt1, prompt2=prompt2),
     )
     if "<prompt>" in evolved_prompt:
         evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
@@ -115,7 +118,7 @@ def evolution_de(prompt1: str, prompt2: str, basic_prompt: str, best_prompt: str
             prompt2=prompt2,
             prompt3=best_prompt,
             basic_prompt=basic_prompt,
-        )
+        ),
     )
     if "<prompt>" in evolved_prompt:
         evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
diff --git a/models.py b/models.py
index 053223d..eaa8f20 100644
--- a/models.py
+++ b/models.py
@@ -1,3 +1,4 @@
+from abc import abstractmethod
 from pathlib import Path
 from typing import Any
 
@@ -18,6 +19,22 @@ class LLMModel:
         self.chat = chat
         self.model = model
 
+    @abstractmethod
+    def __call__(
+        self,
+        system_message: str | None,
+        prompt: str,
+        *,
+        prompt_appendix: str,
+        prompt_prefix: str,
+        prompt_suffix: str,
+        chat: bool | None,
+        stop: str,
+        max_tokens: int,
+        **kwargs: Any,
+    ) -> Any:
+        pass
+
 
 class Llama2(LLMModel):
     """Loads and queries a Llama2 model."""
@@ -30,8 +47,9 @@ class Llama2(LLMModel):
         n_threads: int = 8,
         n_ctx: int = 4096,
         verbose: bool = False,
-        **kwargs
+        **kwargs,
     ) -> None:
+        # initialize model
 
         model = Llama(
             model_path,
@@ -46,45 +64,58 @@ class Llama2(LLMModel):
 
     def __call__(
         self,
-        system_message: str,
+        system_message: str | None,
         prompt: str,
+        *,
         prompt_appendix: str = "",
         prompt_prefix: str = "",
         prompt_suffix: str = "",
         chat: bool | None = None,
         stop: str = "</prompt>",
         max_tokens: int = 200,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> tuple[str, ModelUsage]:
         if chat is None:
             chat = self.chat
 
         if chat:
-            response = self.model.create_chat_completion(
-                messages=[
+            messages = [
+                {
+                    "role": "user",
+                    "content": prompt_prefix + prompt + prompt_suffix + prompt_appendix,
+                }
+            ]
+            if system_message:
+                messages.insert(
+                    0,
                     {
                         "role": "system",
                         "content": system_message,
                     },
-                    {
-                        "role": "user",
-                        "content": prompt + prompt_appendix,
-                    }
-                ],
+                )
+            response = self.model.create_chat_completion(
+                messages=messages,
                 stop=stop,
                 max_tokens=max_tokens,
                 **kwargs,
             )
-            usage = ModelUsage(**response["usage"])
-            self.usage += usage
-            return response["choices"][0]["message"]["content"], usage
+            response_text = response["choices"][0]["message"]["content"]
         else:
             response = self.model.create_completion(
-                prompt=system_message + prompt_prefix + prompt + prompt_suffix + prompt_appendix, stop=stop, max_tokens=max_tokens, **kwargs
+                prompt=(system_message if system_message else "")
+                + prompt_prefix
+                + prompt
+                + prompt_suffix
+                + prompt_appendix,
+                stop=stop,
+                max_tokens=max_tokens,
+                **kwargs,
             )
-            usage = ModelUsage(**response["usage"])
-            self.usage += usage
-            return response["choices"][0]["text"], usage
+            response_text = response["choices"][0]["text"]
+        # input(f"Response: {response_text}")
+        usage = ModelUsage(**response["usage"])
+        self.usage += usage
+        return response_text, usage
 
 
 class OpenAI(LLMModel):
@@ -95,33 +126,46 @@ class OpenAI(LLMModel):
     ) -> None:
         super().__init__(chat, model)
 
-        # initialize model
+        # initialize client for API calls
        self.openai_client = openai.OpenAI(**kwargs)
 
     def __call__(
         self,
+        system_message: str | None,
         prompt: str,
-        chat: bool = None,
+        *,
+        prompt_appendix: str = "",
+        prompt_prefix: str = "",
+        prompt_suffix: str = "",
+        chat: bool | None = None,
         stop: str = "</prompt>",
         max_tokens: int = 200,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> tuple[str, ModelUsage]:
         if chat is None:
             chat = self.chat
 
         if chat:
+            messages = [
+                {
+                    "role": "user",
+                    "content": prompt_prefix + prompt + prompt_suffix + prompt_appendix,
+                }
+            ]
+            if system_message:
+                messages.insert(
+                    0,
+                    {
+                        "role": "system",
+                        "content": system_message,
+                    },
+                )
             response = self.openai_client.chat.completions.create(
-                model=self.model,
-                messages=[
-                    # TODO consider system message
-                    {
-                        "role": "user",
-                        "content": prompt,
-                    }
-                ],
-                stop=stop,
-                max_tokens=max_tokens,
-                **kwargs,
+                    model=self.model,
+                    messages=messages,
+                    stop=stop,
+                    max_tokens=max_tokens,
+                    **kwargs,
             )
             usage = ModelUsage(**response.usage.__dict__)
             self.usage += usage
@@ -129,7 +173,11 @@ class OpenAI(LLMModel):
         else:
             response = self.openai_client.completions.create(
                 model=self.model,
-                prompt=prompt,
+                prompt=(system_message if system_message else "")
+                + prompt_prefix
+                + prompt
+                + prompt_suffix
+                + prompt_appendix,
                 stop=stop,
                 max_tokens=max_tokens,
                 **kwargs,
diff --git a/task.py b/task.py
index 05dea6d..529a62f 100644
--- a/task.py
+++ b/task.py
@@ -12,7 +12,7 @@ from tqdm import tqdm
 from models import Llama2, OpenAI
 from utils import ModelUsage, log_calls, logger
 
-CLASSIFICATION_PROMPT = """
+SYSTEM_MESSAGE = """
 You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 """
 
@@ -103,11 +103,10 @@ class SentimentAnalysis(Task):
         # run model for inference using grammar to constrain output
         # TODO grammar also depends on prompt and vice-versa -> what are good labels?
         response, usage = self.model(
-            system_message=CLASSIFICATION_PROMPT,
+            system_message=SYSTEM_MESSAGE,
             prompt=prompt,
-            prompt_appendix="\nInput: " + "\"" + text + "\"",
+            prompt_appendix="\nInput: " + '"' + text + '"',
             grammar=sa_grammar_fn() if self.use_grammar else None,
-            chat=False if self.use_grammar else True,
         )
 
         if not self.use_grammar:
@@ -207,20 +206,6 @@ class QuestionAnswering(Task):
         validation_split: str | None = None,
         test_split: str | None = None,
     ) -> None:
-        self.evaluation_prompt = """
-        Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
-
-        ### Instruction:
-        {instruction}
-
-        ### Context:
-        {context}
-
-        ### Question:
-        {question}
-
-        ### Response:
-        """
 
         self.metric = load_metric("squad")
 
@@ -249,13 +234,17 @@ class QuestionAnswering(Task):
         )
 
         response, usage = self.model(
-            prompt=self.evaluation_prompt.format(
-                instruction=prompt,
-                context=context,
-                question=question,
-            ),
+            system_message=SYSTEM_MESSAGE,
+            prompt=prompt,
+            prompt_appendix="\nContext: "
+            + '"'
+            + context
+            + '"'
+            + "\nQuestion: "
+            + '"'
+            + question
+            + '"',
             grammar=grammar,
-            chat=False if self.use_grammar else True,
         )
 
         if not self.use_grammar:
@@ -289,7 +278,9 @@ class QuestionAnswering(Task):
                 prompt,
                 context=datum["context"],
                 question=datum["question"],
-            ).lower()
+            )
+            # TODO check if answer is lower-cased in metric computation
+
             evaluation_usage += usage
             num_samples += 1
 
@@ -313,4 +304,4 @@ class QuestionAnswering(Task):
     @property
     def base_prompt(self):
         # TODO find good prompt
-        return """In this task, you are given contexts with questions. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
+        return """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
diff --git a/utils.py b/utils.py
index b1c9918..fd3cd67 100644
--- a/utils.py
+++ b/utils.py
@@ -23,7 +23,7 @@
 Only return the name without any text before or after.""".strip()
 
 def initialize_run_directory(model: OpenAI | Llama2):
-    response, usage = model(run_name_prompt, chat=True)
+    response, usage = model(None, run_name_prompt, chat=True)
     model.usage -= usage
     run_name_match = re.search(r"^\w+$", response, re.MULTILINE)
     if run_name_match is None:
@@ -31,6 +31,7 @@
     else:
         run_name = run_name_match.group(0)
     run_directory = current_directory / f"runs/{run_name}"
+    # TODO what if name exists?
     run_directory.mkdir(parents=True, exist_ok=False)
     file_handler = logging.FileHandler(run_directory / "output.log")
     file_handler.setLevel(logging.DEBUG)
-- 
GitLab
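
For reference, a minimal usage sketch (not part of the patch) of the revised LLMModel.__call__ interface: system_message is now the first, optional argument for both backends, and prompt_prefix / prompt_suffix / prompt_appendix are keyword-only. The constructor arguments and model name below are illustrative assumptions, not taken from this diff.

    from models import OpenAI

    # Hypothetical setup; OpenAI.__init__ forwards extra kwargs to openai.OpenAI(),
    # so the API key is expected in the OPENAI_API_KEY environment variable.
    model = OpenAI(model="gpt-3.5-turbo", chat=True)  # model name is an assumption

    # Chat mode inserts the system message as a leading "system" role message.
    response, usage = model(
        system_message="Please follow the instruction step-by-step to generate a better prompt.",
        prompt="1. Cross over the following prompts and generate a new prompt: ...",
        prompt_suffix='"',
    )

    # Callers that need no system message now pass None explicitly,
    # as utils.initialize_run_directory does with model(None, run_name_prompt, chat=True).
    response, usage = model(None, "Generate a short run name.", chat=True)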