From ec6c2f39989775238e171b4bee2cbf26075cc46c Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Mon, 18 Mar 2024 15:02:41 +0100
Subject: [PATCH] Fix llama2 prompt format

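Use a dedicated system message for the static task instructions instead of
embedding everything in an Alpaca-style template, so that the llama-2 chat
handler can render the conversation in the model's native format. For chat
completions this comes out roughly as follows (a sketch; the exact tokens
depend on the loaded chat template):

    <s>[INST] <<SYS>>
    {system_message}
    <</SYS>>

    {prompt}{prompt_appendix} [/INST]

Plain completions have no template, so the pieces are concatenated verbatim
as system_message + prompt_prefix + prompt + prompt_suffix + prompt_appendix;
the prefix/suffix arguments therefore only take effect in completion mode,
where quoting has to be added by hand.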
---
 main.py   | 26 ++++++++++++--------------
 models.py | 20 ++++++++++++++++----
 task.py   | 17 ++++++-----------
 3 files changed, 34 insertions(+), 29 deletions(-)

diff --git a/main.py b/main.py
index d4cd0e6..66676ca 100644
--- a/main.py
+++ b/main.py
@@ -25,15 +25,7 @@ def conv2bool(_str: Any):
 
 load_dotenv()
 
-PARAPHRASE_PROMPT = """
-Below is an instruction that describes a task. Write a response that paraphrases the instruction. Only output the paraphrased instruction bracketed in <prompt> and </prompt>.
-
-### Instruction:
-{instruction}
-
-### Response:
-<prompt>
-"""
+PARAPHRASE_PROMPT = """You are given an instruction that describes a task. Write a response that paraphrases the instruction. Only output the paraphrased instruction bracketed in <prompt> and </prompt>."""
 
 
 @log_calls("Paraphrasing prompts")
@@ -42,9 +34,14 @@ def paraphrase_prompts(prompt: str, n: int):
     paraphrases = []
     for _ in range(n):
         paraphrase, usage = evolution_model(
-            prompt=PARAPHRASE_PROMPT.format(instruction=prompt)
+            system_message=PARAPHRASE_PROMPT,
+            prompt=prompt,
+            prompt_prefix=" Instruction: \"",
+            prompt_suffix="\"",
         )
         total_usage += usage
+        if "<prompt>" in paraphrase:
+            paraphrase = paraphrase.split("<prompt>")[1].split("</prompt>")[0]
         paraphrases.append(paraphrase)
-    return paraphrases, usage
+    return paraphrases, total_usage
 
@@ -67,9 +64,9 @@ def selection(prompts):
         selection_probabilities = [score / sum(scores) for score in scores]
     return choice(prompts, size=2, replace=False, p=selection_probabilities)
 
+SYSTEM_MESSAGE = "Please follow the instruction step-by-step to generate a better prompt."
 
 GA_PROMPT = """
-Please follow the instruction step-by-step to generate a better prompt.
 1. Cross over the following prompts and generate a new prompt:
 Prompt 1: {prompt1}
 Prompt 2: {prompt2}
@@ -78,7 +75,6 @@ Prompt 2: {prompt2}
 
 
 DE_PROMPT = """
-Please follow the instruction step-by-step to generate a better prompt.
 1. Identify the different parts between the Prompt 1 and Prompt 2:
 Prompt 1: {prompt1}
 Prompt 2: {prompt2}
@@ -101,6 +97,7 @@ def evolution_ga(prompt1: str, prompt2: str):
     # Based on this two-step process, we design instructions, guiding LLMs to
     # generate a new prompt based on these steps to perform Evo(·) in Algorithm 1.
     evolved_prompt, usage = evolution_model(
+        system_message=SYSTEM_MESSAGE,
         prompt=GA_PROMPT.format(prompt1=prompt1, prompt2=prompt2)
     )
     if "<prompt>" in evolved_prompt:
@@ -112,6 +109,7 @@ def evolution_ga(prompt1: str, prompt2: str):
 def evolution_de(prompt1: str, prompt2: str, basic_prompt: str, best_prompt: str):
     # TODO add comment from paper
     evolved_prompt, usage = evolution_model(
+        system_message=SYSTEM_MESSAGE,
         prompt=DE_PROMPT.format(
             prompt1=prompt1,
             prompt2=prompt2,
@@ -323,7 +321,7 @@ if __name__ == "__main__":
         debug = conv2bool(os.getenv("EP_DEBUG", False))
         if debug is None:
             raise ValueError(
-                f"{os.getenv('EP_DEBUG')} is not allowed for env variable EP_DEBUG."
+                f"'{os.getenv('EP_DEBUG')}' is not allowed for env variable EP_DEBUG."
             )
 
     match options.task:
diff --git a/models.py b/models.py
index b58eed5..053223d 100644
--- a/models.py
+++ b/models.py
@@ -46,7 +46,11 @@ class Llama2(LLMModel):
 
     def __call__(
         self,
+        system_message: str,
         prompt: str,
+        prompt_appendix: str = "",
+        prompt_prefix: str = "",
+        prompt_suffix: str = "",
         chat: bool | None = None,
         stop: str = "</prompt>",
         max_tokens: int = 200,
@@ -57,11 +61,14 @@ class Llama2(LLMModel):
 
         if chat:
             response = self.model.create_chat_completion(
-                # TODO add system message?
                 messages=[
+                    {
+                        "role": "system",
+                        "content": system_message,
+                    },
                     {
                         "role": "user",
-                        "content": prompt,
+                        "content": prompt + prompt_appendix,
                     }
                 ],
                 stop=stop,
@@ -73,7 +80,11 @@ class Llama2(LLMModel):
             return response["choices"][0]["message"]["content"], usage
         else:
             response = self.model.create_completion(
-                prompt=prompt, stop=stop, max_tokens=max_tokens, **kwargs
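+                # no chat template in completion mode, so assemble the full prompt by concatenation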
+                prompt=system_message + prompt_prefix + prompt + prompt_suffix + prompt_appendix,
+                stop=stop,
+                max_tokens=max_tokens,
+                **kwargs,
             )
             usage = ModelUsage(**response["usage"])
             self.usage += usage
@@ -104,16 +115,17 @@ class OpenAI(LLMModel):
 
         if chat:
             response = self.openai_client.chat.completions.create(
                 model=self.model,
                 messages=[
+                    # TODO consider system message
                     {
                         "role": "user",
                         "content": prompt,
                     }
                 ],
                 stop=stop,
                 max_tokens=max_tokens,
                 **kwargs,
             )
             usage = ModelUsage(**response.usage.__dict__)
             self.usage += usage
@@ -128,4 +140,4 @@ class OpenAI(LLMModel):
             )
             usage = ModelUsage(**response.usage.__dict__)
             self.usage += usage
-            return response.choices[0].message.content, usage
+            return response.choices[0].text, usage
diff --git a/task.py b/task.py
index 332ffb8..05dea6d 100644
--- a/task.py
+++ b/task.py
@@ -13,15 +13,7 @@ from models import Llama2, OpenAI
 from utils import ModelUsage, log_calls, logger
 
 CLASSIFICATION_PROMPT = """
-Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
-
-### Instruction:
-{instruction}.
-
-### Input:
-{input}
-
-### Response:
+You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 """
 
 
@@ -111,7 +103,10 @@ class SentimentAnalysis(Task):
         # run model for inference using grammar to constrain output
         # TODO grammar also depends on prompt and vice-versa -> what are good labels?
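+        # the instruction becomes the system message; the raw input text is appended to the user prompt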
         response, usage = self.model(
-            prompt=CLASSIFICATION_PROMPT.format(instruction=prompt, input=text),
+            system_message=CLASSIFICATION_PROMPT,
+            prompt=prompt,
+            prompt_appendix=f'\nInput: "{text}"',
             grammar=sa_grammar_fn() if self.use_grammar else None,
             chat=False if self.use_grammar else True,
         )
@@ -164,7 +159,7 @@ class SentimentAnalysis(Task):
     @property
     def base_prompt(self):
         #  from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning
-        return """In this task, you are given sentences from movie reviews. The task is to classify a sentence as "’positive’" if the sentiment of the sentence is positive or as "’negative’" if the sentiment of the sentence is negative. Return label only without any other text."""
+        return """In this task, you are given sentences from movie reviews. The task is to classify a sentence as 'positive' if the sentiment of the sentence is positive or as 'negative' if the sentiment of the sentence is negative. Return label only without any other text."""
 
 
 def grammar_continuous_with_arbitrary_end(sequence: list[str], quote: str = '"'):
-- 
GitLab