From 4ecbf23189986743fad238e9692e35b1777074bb Mon Sep 17 00:00:00 2001 From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de> Date: Tue, 16 Jan 2024 13:52:35 +0100 Subject: [PATCH] Add DE --- main.py | 74 +++++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 61 insertions(+), 13 deletions(-) diff --git a/main.py b/main.py index afa23ec..4344fbe 100644 --- a/main.py +++ b/main.py @@ -146,9 +146,8 @@ Prompt 3: {prompt3} Basic Prompt: {basic_prompt} """ - -@log_calls("Performing prompt evolution") -def evolution(prompt1: str, prompt2: str): +@log_calls("Performing prompt evolution using GA") +def evolution_ga(prompt1: str, prompt2: str): # Following the evolutionary operators in GA, a new candidate prompt is generated through # a two-step process based on the selected two parents: # 1) The parent prompts undergo crossover, resulting in a new prompt that @@ -172,6 +171,29 @@ def evolution(prompt1: str, prompt2: str): return evolved_prompt +@log_calls("Performing prompt evolution using DE") +def evolution_de(prompt1: str, prompt2: str, basic_prompt: str, best_prompt: str): + # TODO add comment from paper + evolved_prompt = llm( + messages=[ + { + "role": "user", + "content": DE_PROMPT.format( + prompt1=prompt1, + prompt2=prompt2, + prompt3=best_prompt, + basic_prompt=basic_prompt, + ), + } + ], + max_tokens=None, + stop="</prompt>", + ) + if "<prompt>" in evolved_prompt: + evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0] + return evolved_prompt + + @log_calls("Updating prompts") def update(prompts: list[str], scores: list[float], N: int): # EVOPROMPT iteratively generates new candidate prompts and assesses each prompt @@ -195,7 +217,7 @@ def update(prompts: list[str], scores: list[float], N: int): return retained_prompts, retained_scores -def run_episode(): +def run_episode(evo_alg_str: str): # Algorithm 1 Discrete prompt optimization: EVOPROMPT # Require: @@ -225,23 +247,51 @@ def run_episode(): for t in 
trange(1, T + 1, desc="T", leave=True): # Line 3: Selection: select a certain number of prompts from current population as parent prompts # pr1,...,prk ∼ Pt−1 - for _ in trange(N, desc="N", leave=False): + if evo_alg_str == "de": + # DE needs best prompt for evolution + best_prompt_current_evolution = P[t - 1][max(range(N), key=lambda i: S[t - 1][i])] + for i in trange(N, desc="N", leave=False): + new_evolutions = [] + new_evolutions_scores = [] + # for both GA and DE we start with two parent prompts pr1, pr2 = selection(P[t - 1], S[t - 1]) # Line 4: Evolution: generate a new prompt based on the selected parent prompts by leveraging LLM to perform evolutionary operators # p′i â†Evo(pr1,...,prk) - p_i = evolution(pr1, pr2) - family_tree[p_i] = (pr1, pr2) + if evo_alg_str == "ga": + p_i = evolution_ga(pr1, pr2) + elif evo_alg_str == "de": + p_i = evolution_de(pr1, pr2, P[t - 1][i], best_prompt_current_evolution) + + family_tree[p_i] = pr1, pr2 # Line 5: Evaluation # s′_i ↠f(p′i,D) s_i = f_D(p_i) - P[t - 1].append(p_i) - S[t - 1].append(s_i) + if evo_alg_str == "ga": + # for GA we consider all evolutions and select best afterward + new_evolutions.append(p_i) + new_evolutions_scores.append(s_i) + elif evo_alg_str == "de": + # for DE we keep the evolved prompt if it is better than the basic prompt, and use the basic prompt otherwise + if s_i > S[t - 1][i]: + new_evolutions.append(p_i) + new_evolutions_scores.append(s_i) + else: + new_evolutions.append(P[t - 1][i]) + new_evolutions_scores.append(S[t - 1][i]) # Line 6: Update based on the evaluation scores # Pt ↠{Pt−1, p′i} and St ↠{St−1, s′i} - P_t, S_t = update(P[t - 1], S[t - 1], N) + if evo_alg_str == "ga": + # GA keeps N best prompts from current population and evolutions + P_t, S_t = update( + P[t - 1] + new_evolutions, S[t - 1] + new_evolutions_scores, N + ) + elif evo_alg_str == "de": + # DE only keeps evolved or basic prompt in each step, whichever is better + P_t, S_t = new_evolutions, new_evolutions_scores + 
P.append(P_t) S.append(S_t) @@ -262,8 +312,6 @@ def run_episode(): family_tree = {} if __name__ == "__main__": options = argument_parser.parse_args() - if options.evolution_algorithm == "de": - raise NotImplementedError("DE is not implemented yet") match options.evolution_engine: case "llama2": logger.info("Using Llama2 client as the evolution engine") @@ -285,4 +333,4 @@ if __name__ == "__main__": .message.content ) - run_episode() + run_episode(evo_alg_str=options.evolution_algorithm) -- GitLab