From efd17f2f7fcd4c537019d44ff779148544a99997 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Thu, 7 Mar 2024 10:42:10 +0100
Subject: [PATCH] clean up datamodel to avoid duplication and improve
 interpretability

---
 main.py  | 106 +++++++++++++++++++++++++------------------------------
 task.py  |   7 ++--
 utils.py |  66 ++--------------------------------
 3 files changed, 53 insertions(+), 126 deletions(-)

diff --git a/main.py b/main.py
index b16e646..b65ab4a 100644
--- a/main.py
+++ b/main.py
@@ -60,7 +60,7 @@ def paraphrase_prompts(prompt: str, n: int):
 
 
 @log_calls("Performing selection")
-def selection(prompts, scores):
+def selection(prompts):
     # In GA, two parent solutions are normally selected based on the roulette wheel
     # selection method according to the fitness value (Lipowski & Lipowska, 2012).
     # Similar to this, we utilize the roulette wheel selection method to select
@@ -69,6 +69,7 @@ def selection(prompts, scores):
     # development set of the i-th prompt in the population, which contains a total
     # of N prompts. The probability of selecting the i-th prompt as a parent can be expressed as
     # pi = si / Σj=1->N sj.
+    scores = [prompt.score for prompt in prompts]
     if sum(scores) == 0:
         # sum of scores is 0 ==> each score is 0, draw with equal probability
         selection_probabilities = len(scores) * [1 / len(scores)]
@@ -134,7 +135,7 @@ def evolution_de(prompt1: str, prompt2: str, basic_prompt: str, best_prompt: str
 
 
 @log_calls("Updating prompts")
-def update(prompts: list[str], scores: list[float], N: int):
+def update(prompts: list[str], N: int):
     # EVOPROMPT iteratively generates new candidate prompts and assesses each prompt
     # using a development set, denoted as D, to obtain a score that quantifies the
     # quality of the prompt. We consider a straightforward selection strategy.
@@ -142,18 +143,16 @@ def update(prompts: list[str], scores: list[float], N: int):
     # which are combined with the current population of N prompts.
     # The updated population is then selected by retaining the N prompts with the highest scores.
     retained_prompts = []
-    retained_scores = []
-
-    for prompt, score in zip(prompts, scores):
+    min_retained_score = 0
+    for prompt in prompts:
         if len(retained_prompts) < N:
             retained_prompts.append(prompt)
-            retained_scores.append(score)
-        elif score > min(retained_scores):
-            min_index = retained_scores.index(min(retained_scores))
-            retained_prompts[min_index] = prompt
-            retained_scores[min_index] = score
+            min_retained_score = min(min_retained_score, prompt.score)
+        elif prompt.score > min_retained_score:
+            retained_prompts.sort(key=lambda p: p.score)
+            retained_prompts[0] = prompt
 
-    return retained_prompts, retained_scores
+    return retained_prompts
 
 
 def run_episode(evo_alg_str: str, debug: bool = False):
@@ -178,49 +177,55 @@ def run_episode(evo_alg_str: str, debug: bool = False):
     # Line 1: Initial evaluation scores: S0 ← {si = fD (pi )|i ∈ [1, N ]}
     # the current population's scores
     population_scores = [f_D(p) for p in initial_population]
-    # S keeps track of scores
-    S = [population_scores]
-    # P keeps track of prompts and its generations
+    # all_prompts maps prompt ids to the Prompt objects that took part in this run at some point
     # converting prompts to Prompt object
-    P = [
-        [
+    all_prompts: dict[str, Prompt] = {
+        prompt.id: prompt
+        for prompt in [
             Prompt(p, score=score, gen=0)
-            for idx, (p, score) in enumerate(zip(initial_population, population_scores))
+            for (p, score) in zip(initial_population, population_scores)
         ]
-    ]
+    }
+
+    # P keeps track of prompts in each generation
+    P = [[prompt_id for prompt_id in all_prompts.keys()]]
 
     # add initial prompts to family tree
     # None marks that there is no parent
     family_tree: dict[str, tuple[str, str] | None] = {
-        prompt.id: None for prompt in P[0]
+        prompt_id: None for prompt_id in P[0]
     }
 
-    # evolution = EvolutionGA(num_evolutions=N)
-
     # Line 2:
     for t in trange(1, T + 1, desc="T", leave=True):
         # Line 3: Selection: select a certain number of prompts from current population as parent prompts
         # pr1,...,prk ∼ Pt−1
+        prompts_current_evolution = [all_prompts[prompt_id] for prompt_id in P[t - 1]]
         if evo_alg_str == "de":
             # DE needs best prompt for evolution
-            best_prompt_current_evolution = P[t - 1][
-                max(range(N), key=lambda i: S[t - 1][i])
-            ]
+            best_prompt_current_evolution = max(
+                prompts_current_evolution, key=lambda prompt: prompt.score
+            )
 
         new_evolutions = []
         new_evolutions_scores = []
 
         for i in trange(N, desc="N", leave=False):
             # for both GA and DE we start with two parent prompts
-            pr1, pr2 = selection(P[t - 1], S[t - 1])
+            pr1, pr2 = selection([all_prompts[prompt_id] for prompt_id in P[t - 1]])
 
             # Line 4: Evolution: generate a new prompt based on the selected parent prompts by leveraging LLM to perform evolutionary operators
             # p′i ← Evo(pr1,...,prk)
             if evo_alg_str == "ga":
                 p_i = evolution_ga(pr1, pr2)
             elif evo_alg_str == "de":
-                p_i = evolution_de(pr1, pr2, P[t - 1][i], best_prompt_current_evolution)
+                p_i = evolution_de(
+                    pr1,
+                    pr2,
+                    prompts_current_evolution[i],
+                    best_prompt_current_evolution,
+                )
 
             # Line 5: Evaluation
             # s′_i ← f(p′i,D)
@@ -232,52 +237,37 @@ def run_episode(evo_alg_str: str, debug: bool = False):
             family_tree[evolved_prompt.id] = (pr1.id, pr2.id)
 
             new_evolutions.append(evolved_prompt)
-            new_evolutions_scores.append(s_i)
-
+        all_prompts |= {prompt.id: prompt for prompt in new_evolutions}
         # Line 6: Update based on the evaluation scores
         # Pt ← {Pt−1, p′i} and St ← {St−1, s′i}
         if evo_alg_str == "ga":
             # GA keeps N best prompts from current population and evolutions
-            population, population_scores = update(
-                new_evolutions + P[t - 1], new_evolutions_scores + S[t - 1], N
-            )
+            population = update(new_evolutions + prompts_current_evolution, N)
         elif evo_alg_str == "de":
             # for DE we keep the evolved prompt if it is better than the basic prompt, and use the basic prompt otherwise
-            assert (
-                len(P[t - 1])
-                == len(S[t - 1])
-                == len(new_evolutions)
-                == len(new_evolutions_scores)
-            )
-            population, population_scores = list(
-                zip(
-                    *[
-                        (
-                            (new_prompt, new_prompt_score)
-                            if new_prompt_score > current_prompt_score
-                            else (current_prompt, current_prompt_score)
-                        )
-                        for current_prompt, current_prompt_score, new_prompt, new_prompt_score in zip(
-                            P[t - 1], S[t - 1], new_evolutions, new_evolutions_scores
-                        )
-                    ]
+            assert len(prompts_current_evolution) == len(new_evolutions)
+            population = [
+                (
+                    new_prompt
+                    if new_prompt.score > current_prompt.score
+                    else current_prompt
                 )
-            )
+                for current_prompt, new_prompt in zip(
+                    prompts_current_evolution, new_evolutions
+                )
+            ]
 
         # store new generation
-        P.append(population)
-        S.append(population_scores)
+        P.append([prompt.id for prompt in population])
 
-    save_snapshot(family_tree, P, S, T, N)
-    save_genealogy(family_tree, P, metric_name=task.metric_name)
-    save_family_tree_visualization(family_tree, P, metric_name=task.metric_name)
+    save_snapshot(all_prompts, family_tree, P, T, N)
 
     # Line 8: Return the best prompt, p∗, among the final population PT :
     # p∗ ← argmaxp∈PT f(p, D)
-    p = max(range(N), key=lambda i: S[-1][i])
-    logger.info(f"Best prompt: {P[-1][p]}")
+    p = all_prompts[max(P[-1], key=lambda prompt_id: all_prompts[prompt_id].score)]
+    logger.info(f"Best prompt: {p}")
 
     # We pick the prompt with the highest score on the development set and report its score on the testset.
-    test_performance = task.evaluate_test(P[-1][p])
+    test_performance = task.evaluate_test(p.content)
     logger.info(f"Best prompt on test set: {test_performance}")
 
diff --git a/task.py b/task.py
index c1ba91a..1a18848 100644
--- a/task.py
+++ b/task.py
@@ -1,18 +1,17 @@
+import re
 from abc import abstractmethod
 from collections import defaultdict
 from functools import lru_cache
-import re
 from typing import DefaultDict, Mapping, Union
 
 from datasets import Dataset, load_dataset
 from evaluate import load as load_metric
-from tqdm import tqdm
-from models import Llama2, OpenAI
 from llama_cpp import LlamaGrammar
+from tqdm import tqdm
 
+from models import Llama2, OpenAI
 from utils import log_calls, logger
 
-
 CLASSIFICATION_PROMPT = """
 Below is an instruction that describes a task, paired with an input that provides further context.
 Write a response that appropriately completes the request.
diff --git a/utils.py b/utils.py
index 7548aa1..17f97fb 100644
--- a/utils.py
+++ b/utils.py
@@ -124,75 +124,13 @@ class PromptEncoder(json.JSONEncoder):
         return json.JSONEncoder.default(self, obj)
 
 
-@log_calls("Saving genealogy")
-def save_genealogy(family_tree, P, metric_name: str):
-    node_id = lambda prompt: hash(prompt)
-    dot = Digraph(comment="Genealogy")
-    for t in range(len(P)):
-        with dot.subgraph(name=f"cluster_{t}") as cluster:
-            cluster.attr(
-                label=f"Generation {t}",
-                color="lightgrey",
-            )
-            for prompt, parents in family_tree:
-                if prompt.gen != t:
-                    continue
-                cluster.node(
-                    f"{node_id(prompt)}",
-                    label=f"{fill(str(prompt))}\n{metric_name}: {prompt.score}",
-                )
-                if parents is not None:
-                    parent_1, parent_2 = parents
-                    dot.edge(f"{node_id(parent_1)}", f"{node_id(prompt)}")
-                    dot.edge(f"{node_id(parent_2)}", f"{node_id(prompt)}")
-
-    dot.render(run_directory / "genealogy", engine="dot")
-
-
-def save_family_tree_visualization(family_tree, P, metric_name: str):
-    dot = Digraph(comment="Genealogy")
-    node_id = lambda prompt: hash(prompt)
-    prompt_generations = lambda prompt: [
-        str(t) for t in range(len(P)) if prompt in P[t]
-    ]
-    node_template = dedent(
-        """
-        {prompt}
-        ∈ P{{{prompt_generations}}}
-        {metric_name}: {score}"""
-    )
-    for prompt, parents in family_tree:
-        node_text = node_template.format(
-            prompt=fill(str(prompt)),
-            prompt_generations=", ".join(prompt_generations(prompt)),
-            score=prompt.score,
-            metric_name=metric_name,
-        )
-        dot.node(
-            f"{node_id(prompt)}",
-            label=node_text,
-        )
-    for prompt, parents in family_tree:
-        if parents is not None:
-            # TODO consider more than 2 parents?
-            dot.edge(
-                f"{node_id(parents[0])}",
-                f"{node_id(prompt)}",
-            )
-            dot.edge(
-                f"{node_id(parents[1])}",
-                f"{node_id(prompt)}",
-            )
-
-    dot.render(run_directory / "family_tree", engine="dot")
-
-
-def save_snapshot(family_tree, P, S, T, N):
+def save_snapshot(all_prompts, family_tree, P, S, T, N):
     import json
 
     with open(run_directory / "snapshot.json", "w") as f:
         json.dump(
             {
+                "all_prompts": all_prompts,
                 "family_tree": family_tree,
                 "P": P,
                 "S": S,
-- 
GitLab
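
For readers skimming the patch, the parent selection that the reworked selection() performs over Prompt objects boils down to the roulette wheel described in its comments (pi = si / Σj sj). The snippet below is a minimal, self-contained sketch of that idea, not the repository's code: the Prompt dataclass here is a simplified stand-in with only the fields the patch relies on (content, score, gen, id), and random.choices stands in for whatever sampling primitive main.py actually uses.

import random
import uuid
from dataclasses import dataclass, field


@dataclass
class Prompt:
    # simplified stand-in for the project's Prompt class; only the fields the
    # patch relies on (content, score, gen, id) are assumed here
    content: str
    score: float = 0.0
    gen: int = 0
    id: str = field(default_factory=lambda: uuid.uuid4().hex)


def selection(prompts: list[Prompt], k: int = 2) -> list[Prompt]:
    # roulette-wheel selection: p_i = s_i / sum_j s_j,
    # falling back to a uniform draw when every score is 0
    scores = [prompt.score for prompt in prompts]
    if sum(scores) == 0:
        weights = [1 / len(prompts)] * len(prompts)
    else:
        weights = [score / sum(scores) for score in scores]
    # with-replacement draw, so the same parent can be picked twice
    return random.choices(prompts, weights=weights, k=k)


if __name__ == "__main__":
    population = [
        Prompt(content=f"instruction variant {i}", score=score)
        for i, score in enumerate([0.2, 0.5, 0.3])
    ]
    pr1, pr2 = selection(population)
    print(pr1.content, pr2.content)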