diff --git a/main.py b/main.py
index 66676ca157e2577a9a9ac6ad15439a65989f5d93..5524016966d5ba565f06e06131024660e2e0b0b3 100644
--- a/main.py
+++ b/main.py
@@ -36,8 +36,8 @@ def paraphrase_prompts(prompt: str, n: int):
         paraphrase, usage = evolution_model(
             system_message=PARAPHRASE_PROMPT,
             prompt=prompt,
-            prompt_prefix=" Instruction: \"",
-            prompt_suffix="\"",
+            prompt_prefix=' Instruction: "',
+            prompt_suffix='"',
         )
         total_usage += usage
         if "<prompt>" in paraphrase:
@@ -64,7 +64,10 @@ def selection(prompts):
     selection_probabilities = [score / sum(scores) for score in scores]
     return choice(prompts, size=2, replace=False, p=selection_probabilities)
 
-SYSTEM_MESSAGE = "Please follow the instruction step-by-step to generate a better prompt."
+
+SYSTEM_MESSAGE = (
+    "Please follow the instruction step-by-step to generate a better prompt."
+)
 
 GA_PROMPT = """
 1. Cross over the following prompts and generate a new prompt:
@@ -98,7 +101,7 @@ def evolution_ga(prompt1: str, prompt2: str):
     # generate a new prompt based on these steps to perform Evo(·) in Algorithm 1.
     evolved_prompt, usage = evolution_model(
         system_message=SYSTEM_MESSAGE,
-        prompt=GA_PROMPT.format(prompt1=prompt1, prompt2=prompt2)
+        prompt=GA_PROMPT.format(prompt1=prompt1, prompt2=prompt2),
     )
     if "<prompt>" in evolved_prompt:
         evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
@@ -115,7 +118,7 @@ def evolution_de(prompt1: str, prompt2: str, basic_prompt: str, best_prompt: str
             prompt2=prompt2,
             prompt3=best_prompt,
             basic_prompt=basic_prompt,
-        )
+        ),
     )
     if "<prompt>" in evolved_prompt:
         evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
diff --git a/models.py b/models.py
index 053223d57c97be56c07945a3719ebd2c03a75264..eaa8f208c27837df9e1dcf544d5f505001c37997 100644
--- a/models.py
+++ b/models.py
@@ -1,3 +1,4 @@
+from abc import abstractmethod
 from pathlib import Path
 from typing import Any
 
@@ -18,6 +19,22 @@ class LLMModel:
         self.chat = chat
         self.model = model
 
+    @abstractmethod
+    def __call__(
+        self,
+        system_message: str | None,
+        prompt: str,
+        *,
+        prompt_appendix: str,
+        prompt_prefix: str,
+        prompt_suffix: str,
+        chat: bool | None,
+        stop: str,
+        max_tokens: int,
+        **kwargs: Any,
+    ) -> Any:
+        pass
+
 
 class Llama2(LLMModel):
     """Loads and queries a Llama2 model."""
@@ -30,8 +47,9 @@ class Llama2(LLMModel):
         n_threads: int = 8,
         n_ctx: int = 4096,
         verbose: bool = False,
-        **kwargs
+        **kwargs,
     ) -> None:
 
+        # initialize model
         model = Llama(
             model_path,
@@ -46,45 +64,58 @@ class Llama2(LLMModel):
 
     def __call__(
         self,
-        system_message: str,
+        system_message: str | None,
         prompt: str,
+        *,
         prompt_appendix: str = "",
         prompt_prefix: str = "",
         prompt_suffix: str = "",
         chat: bool | None = None,
         stop: str = "</prompt>",
         max_tokens: int = 200,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> tuple[str, ModelUsage]:
         if chat is None:
             chat = self.chat
 
         if chat:
-            response = self.model.create_chat_completion(
-                messages=[
+            messages = [
+                {
+                    "role": "user",
+                    "content": prompt_prefix + prompt + prompt_suffix + prompt_appendix,
+                }
+            ]
+            if system_message:
+                messages.insert(
+                    0,
                     {
                         "role": "system",
                         "content": system_message,
                     },
-                    {
-                        "role": "user",
-                        "content": prompt + prompt_appendix,
-                    }
-                ],
+                )
+            response = self.model.create_chat_completion(
+                messages=messages,
                 stop=stop,
                 max_tokens=max_tokens,
                 **kwargs,
             )
-            usage = ModelUsage(**response["usage"])
-            self.usage += usage
-            return response["choices"][0]["message"]["content"], usage
+            response_text = response["choices"][0]["message"]["content"]
         else:
             response = self.model.create_completion(
-                prompt=system_message + prompt_prefix + prompt + prompt_suffix + prompt_appendix, stop=stop, max_tokens=max_tokens, **kwargs
+                prompt=(system_message if system_message else "")
+                + prompt_prefix
+                + prompt
+                + prompt_suffix
+                + prompt_appendix,
+                stop=stop,
+                max_tokens=max_tokens,
+                **kwargs,
             )
-            usage = ModelUsage(**response["usage"])
-            self.usage += usage
-            return response["choices"][0]["text"], usage
+            response_text = response["choices"][0]["text"]
+        # input(f"Response: {response_text}")
+        usage = ModelUsage(**response["usage"])
+        self.usage += usage
+        return response_text, usage
 
 
 class OpenAI(LLMModel):
@@ -95,33 +126,46 @@ class OpenAI(LLMModel):
     ) -> None:
         super().__init__(chat, model)
 
-        # initialize model
+        # initialize client for API calls
        self.openai_client = openai.OpenAI(**kwargs)
 
     def __call__(
         self,
+        system_message: str | None,
         prompt: str,
-        chat: bool = None,
+        *,
+        prompt_appendix: str = "",
+        prompt_prefix: str = "",
+        prompt_suffix: str = "",
+        chat: bool | None = None,
         stop: str = "</prompt>",
         max_tokens: int = 200,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> tuple[str, ModelUsage]:
         if chat is None:
             chat = self.chat
 
         if chat:
+            messages = [
+                {
+                    "role": "user",
+                    "content": prompt_prefix + prompt + prompt_suffix + prompt_appendix,
+                }
+            ]
+            if system_message:
+                messages.insert(
+                    0,
+                    {
+                        "role": "system",
+                        "content": system_message,
+                    },
+                )
             response = self.openai_client.chat.completions.create(
-                model=self.model,
-                messages=[
-                    # TODO consider system message
-                    {
-                        "role": "user",
-                        "content": prompt,
-                    }
-                ],
-                stop=stop,
-                max_tokens=max_tokens,
-                **kwargs,
+                model=self.model,
+                messages=messages,
+                stop=stop,
+                max_tokens=max_tokens,
+                **kwargs,
             )
             usage = ModelUsage(**response.usage.__dict__)
             self.usage += usage
@@ -129,7 +173,11 @@ class OpenAI(LLMModel):
         else:
             response = self.openai_client.completions.create(
                 model=self.model,
-                prompt=prompt,
+                prompt=(system_message if system_message else "")
+                + prompt_prefix
+                + prompt
+                + prompt_suffix
+                + prompt_appendix,
                 stop=stop,
                 max_tokens=max_tokens,
                 **kwargs,
diff --git a/task.py b/task.py
index 05dea6d432b129fe00c4e95afe6dce195e137e7f..529a62f21bc083d91ca87fa629b22f110494260b 100644
--- a/task.py
+++ b/task.py
@@ -12,7 +12,7 @@ from tqdm import tqdm
 from models import Llama2, OpenAI
 from utils import ModelUsage, log_calls, logger
 
-CLASSIFICATION_PROMPT = """
+SYSTEM_MESSAGE = """
 You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 """
 
@@ -103,11 +103,10 @@ class SentimentAnalysis(Task):
             # run model for inference using grammar to constrain output
             # TODO grammar also depends on prompt and vice-versa -> what are good labels?
             response, usage = self.model(
-                system_message=CLASSIFICATION_PROMPT,
+                system_message=SYSTEM_MESSAGE,
                 prompt=prompt,
-                prompt_appendix="\nInput: " + "\"" + text + "\"",
+                prompt_appendix="\nInput: " + '"' + text + '"',
                 grammar=sa_grammar_fn() if self.use_grammar else None,
-                chat=False if self.use_grammar else True,
             )
 
             if not self.use_grammar:
@@ -207,20 +206,6 @@ class QuestionAnswering(Task):
         validation_split: str | None = None,
         test_split: str | None = None,
     ) -> None:
-        self.evaluation_prompt = """
-        Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
-
-        ### Instruction:
-        {instruction}
-
-        ### Context:
-        {context}
-
-        ### Question:
-        {question}
-
-        ### Response:
-        """
 
         self.metric = load_metric("squad")
 
@@ -249,13 +234,17 @@ class QuestionAnswering(Task):
             )
 
             response, usage = self.model(
-                prompt=self.evaluation_prompt.format(
-                    instruction=prompt,
-                    context=context,
-                    question=question,
-                ),
+                system_message=SYSTEM_MESSAGE,
+                prompt=prompt,
+                prompt_appendix="\nContext: "
+                + '"'
+                + context
+                + '"'
+                + "\nQuestion: "
+                + '"'
+                + question
+                + '"',
                 grammar=grammar,
-                chat=False if self.use_grammar else True,
             )
 
             if not self.use_grammar:
@@ -289,7 +278,9 @@ class QuestionAnswering(Task):
                 prompt,
                 context=datum["context"],
                 question=datum["question"],
-            ).lower()
+            )
+            # TODO check if answer is lower-cased in metric computation
+
             evaluation_usage += usage
             num_samples += 1
 
@@ -313,4 +304,4 @@ class QuestionAnswering(Task):
     @property
     def base_prompt(self):
         # TODO find good prompt
-        return """In this task, you are given contexts with questions. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
+        return """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
diff --git a/utils.py b/utils.py
index b1c9918167985c39ae9c4f791f0f8df4183d40ef..fd3cd67dfad2679c39b671161eaa2fa86c6d810a 100644
--- a/utils.py
+++ b/utils.py
@@ -23,7 +23,7 @@ Only return the name without any text before or after.""".strip()
 
 
 def initialize_run_directory(model: OpenAI | Llama2):
-    response, usage = model(run_name_prompt, chat=True)
+    response, usage = model(None, run_name_prompt, chat=True)
     model.usage -= usage
     run_name_match = re.search(r"^\w+$", response, re.MULTILINE)
     if run_name_match is None:
@@ -31,6 +31,7 @@
     else:
         run_name = run_name_match.group(0)
     run_directory = current_directory / f"runs/{run_name}"
+    # TODO what if name exists?
     run_directory.mkdir(parents=True, exist_ok=False)
     file_handler = logging.FileHandler(run_directory / "output.log")
     file_handler.setLevel(logging.DEBUG)