From 4ecbf23189986743fad238e9692e35b1777074bb Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Tue, 16 Jan 2024 13:52:35 +0100
Subject: [PATCH] Add DE (differential evolution)

---
 main.py | 74 +++++++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 61 insertions(+), 13 deletions(-)

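Note on the DE update implemented in this patch: unlike GA, which pools the current population with all new evolutions and then keeps the N best, DE compares each evolved candidate only against the basic prompt at its own index and keeps whichever scores higher. The following is a minimal, self-contained sketch of that selection rule; de_step, evolve and evaluate are illustrative placeholders (not functions from main.py), and random.sample merely stands in for the repository's selection() helper:

    import random


    def de_step(population, scores, evolve, evaluate):
        # current best prompt, passed to the evolution step as "prompt3"
        best = population[max(range(len(scores)), key=lambda i: scores[i])]
        new_population, new_scores = [], []
        for i, basic in enumerate(population):
            # two parent prompts (stand-in for the repo's selection())
            p1, p2 = random.sample(population, 2)
            candidate = evolve(p1, p2, basic, best)
            score = evaluate(candidate)
            # keep the candidate only if it beats the basic prompt at index i
            if score > scores[i]:
                new_population.append(candidate)
                new_scores.append(score)
            else:
                new_population.append(basic)
                new_scores.append(scores[i])
        return new_population, new_scores

With this rule the population size stays at N in every generation, which is why the DE branch below bypasses the update() call used by GA.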
diff --git a/main.py b/main.py
index afa23ec..4344fbe 100644
--- a/main.py
+++ b/main.py
@@ -146,9 +146,8 @@ Prompt 3: {prompt3}
 Basic Prompt: {basic_prompt}
 """
 
-
-@log_calls("Performing prompt evolution")
-def evolution(prompt1: str, prompt2: str):
+@log_calls("Performing prompt evolution using GA")
+def evolution_ga(prompt1: str, prompt2: str):
     # Following the evolutionary operators in GA, a new candidate prompt is generated through
     # a two-step process based on the selected two parents:
     # 1) The parent prompts undergo crossover, resulting in a new prompt that
@@ -172,6 +171,29 @@ def evolution(prompt1: str, prompt2: str):
     return evolved_prompt
 
 
+@log_calls("Performing prompt evolution using DE")
+def evolution_de(prompt1: str, prompt2: str, basic_prompt: str, best_prompt: str):
+    # Following the DE paradigm, a new candidate prompt is generated from two randomly
+    # selected parents, the current best prompt and the basic prompt:
+    # 1) identify the parts in which the two parent prompts differ (mutation candidates),
+    # 2) mutate these differing parts,
+    # 3) crossover: combine the mutated parts with the current best prompt,
+    # 4) replace the corresponding parts of the basic prompt to obtain the new candidate.
+    evolved_prompt = llm(
+        messages=[
+            {
+                "role": "user",
+                "content": DE_PROMPT.format(
+                    prompt1=prompt1,
+                    prompt2=prompt2,
+                    prompt3=best_prompt,
+                    basic_prompt=basic_prompt,
+                ),
+            }
+        ],
+        max_tokens=None,
+        stop="</prompt>",
+    )
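+    # extract the content between the <prompt> tags; the stop sequence normally cuts
+    # generation before "</prompt>", so the closing tag may be missing from the response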
+    if "<prompt>" in evolved_prompt:
+        evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
+    return evolved_prompt
+
+
 @log_calls("Updating prompts")
 def update(prompts: list[str], scores: list[float], N: int):
     # EVOPROMPT iteratively generates new candidate prompts and assesses each prompt
@@ -195,7 +217,7 @@ def update(prompts: list[str], scores: list[float], N: int):
     return retained_prompts, retained_scores
 
 
-def run_episode():
+def run_episode(evo_alg_str: str):
     # Algorithm 1 Discrete prompt optimization: EVOPROMPT
 
     # Require:
@@ -225,23 +247,51 @@ def run_episode():
     for t in trange(1, T + 1, desc="T", leave=True):
         # Line 3: Selection: select a certain number of prompts from current population as parent prompts
         # pr1,...,prk ∼ Pt−1
-        for _ in trange(N, desc="N", leave=False):
+        if evo_alg_str == "de":
+            # DE needs the best prompt of the current population for evolution
+            best_prompt_current_evolution = P[t - 1][
+                max(range(N), key=lambda i: S[t - 1][i])
+            ]
+
+        # candidates generated in this generation and their scores
+        new_evolutions = []
+        new_evolutions_scores = []
+        for i in trange(N, desc="N", leave=False):
+            # for both GA and DE we start with two parent prompts
             pr1, pr2 = selection(P[t - 1], S[t - 1])
 
             # Line 4: Evolution: generate a new prompt based on the selected parent prompts by leveraging LLM to perform evolutionary operators
             # p′i ←Evo(pr1,...,prk)
-            p_i = evolution(pr1, pr2)
-            family_tree[p_i] = (pr1, pr2)
+            if evo_alg_str == "ga":
+                p_i = evolution_ga(pr1, pr2)
+            elif evo_alg_str == "de":
+                p_i = evolution_de(pr1, pr2, P[t - 1][i], best_prompt_current_evolution)
+
+            family_tree[p_i] = (pr1, pr2)
 
             # Line 5: Evaluation
             # s′_i ← f(p′i,D)
             s_i = f_D(p_i)
-            P[t - 1].append(p_i)
-            S[t - 1].append(s_i)
+            if evo_alg_str == "ga":
+                # for GA we collect all new evolutions and select the best prompts afterwards
+                new_evolutions.append(p_i)
+                new_evolutions_scores.append(s_i)
+            elif evo_alg_str == "de":
+                # for DE we keep the evolved prompt if it is better than the basic prompt, and use the basic prompt otherwise
+                if s_i > S[t - 1][i]:
+                    new_evolutions.append(p_i)
+                    new_evolutions_scores.append(s_i)
+                else:
+                    new_evolutions.append(P[t - 1][i])
+                    new_evolutions_scores.append(S[t - 1][i])
 
         # Line 6: Update based on the evaluation scores
         # Pt ← {Pt−1, p′i} and St ← {St−1, s′i}
-        P_t, S_t = update(P[t - 1], S[t - 1], N)
+        if evo_alg_str == "ga":
+            # GA keeps N best prompts from current population and evolutions
+            P_t, S_t = update(
+                P[t - 1] + new_evolutions, S[t - 1] + new_evolutions_scores, N
+            )
+        elif evo_alg_str == "de":
+            # DE keeps, per index, either the evolved or the basic prompt, whichever scored better
+            P_t, S_t = new_evolutions, new_evolutions_scores
+
         P.append(P_t)
         S.append(S_t)
 
@@ -262,8 +312,6 @@ def run_episode():
 family_tree = {}
 if __name__ == "__main__":
     options = argument_parser.parse_args()
-    if options.evolution_algorithm == "de":
-        raise NotImplementedError("DE is not implemented yet")
     match options.evolution_engine:
         case "llama2":
             logger.info("Using Llama2 client as the evolution engine")
@@ -285,4 +333,4 @@ if __name__ == "__main__":
                 .message.content
             )
 
-    run_episode()
+    run_episode(evo_alg_str=options.evolution_algorithm)
-- 
GitLab
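Usage note: the exact CLI flags are defined by argument_parser in main.py and are not shown in this patch; assuming they mirror the attribute names options.evolution_algorithm and options.evolution_engine, a DE run would look roughly like:

    python main.py --evolution-algorithm de --evolution-engine llama2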