Commit efd17f2f authored by Grießhaber Daniel

clean up datamodel to avoid duplication and improve interpretability

parent 316ee981
@@ -60,7 +60,7 @@ def paraphrase_prompts(prompt: str, n: int):
@log_calls("Performing selection")
def selection(prompts, scores):
def selection(prompts):
# In GA, two parent solutions are normally selected based on the roulette wheel
# selection method according to the fitness value (Lipowski & Lipowska, 2012).
# Similar to this, we utilize the roulette wheel selection method to select
@@ -69,6 +69,7 @@ def selection(prompts, scores):
     # development set of the i-th prompt in the population, which contains a total
     # of N prompts. The probability of selecting the i-th prompt as a parent can be
     # expressed as p_i = s_i / Σ_{j=1..N} s_j.
+    scores = [prompt.score for prompt in prompts]
     if sum(scores) == 0:
         # sum of scores is 0 ==> each score is 0, draw with equal probability
         selection_probabilities = len(scores) * [1 / len(scores)]
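
Review note: the roulette-wheel draw these comments describe is easy to get subtly wrong, so here is a minimal reference sketch. It assumes a Prompt object carrying a .score attribute; the dataclass stand-in and function name below are hypothetical, not the repository's API:

import random
from dataclasses import dataclass

@dataclass
class Prompt:
    # hypothetical stand-in for the repository's Prompt datamodel
    content: str
    score: float

def roulette_wheel_pair(prompts: list[Prompt]) -> tuple[Prompt, Prompt]:
    # p_i = s_i / Σ_j s_j; fall back to a uniform draw when every score is 0
    scores = [prompt.score for prompt in prompts]
    total = sum(scores)
    weights = [s / total for s in scores] if total > 0 else None
    # random.choices draws with replacement, so a strong prompt can be both parents
    pr1, pr2 = random.choices(prompts, weights=weights, k=2)
    return pr1, pr2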
@@ -134,7 +135,7 @@ def evolution_de(prompt1: str, prompt2: str, basic_prompt: str, best_prompt: str
@log_calls("Updating prompts")
def update(prompts: list[str], scores: list[float], N: int):
def update(prompts: list[str], N: int):
# EVOPROMPT iteratively generates new candidate prompts and assesses each prompt
# using a development set, denoted as D, to obtain a score that quantifies the
# quality of the prompt. We consider a straightforward selection strategy.
@@ -142,18 +143,16 @@ def update(prompts: list[str], scores: list[float], N: int):
     # which are combined with the current population of N prompts.
     # The updated population is then selected by retaining the N prompts with the highest scores.
     retained_prompts = []
-    retained_scores = []
-    for prompt, score in zip(prompts, scores):
+    min_retained_score = 0
+    for prompt in prompts:
         if len(retained_prompts) < N:
             retained_prompts.append(prompt)
-            retained_scores.append(score)
-        elif score > min(retained_scores):
-            min_index = retained_scores.index(min(retained_scores))
-            retained_prompts[min_index] = prompt
-            retained_scores[min_index] = score
+            min_retained_score = min(min_retained_score, prompt.score)
+        elif prompt.score > min_retained_score:
+            retained_prompts.sort(key=lambda p: p.score)
+            retained_prompts[0] = prompt
-    return retained_prompts, retained_scores
+    return retained_prompts

 def run_episode(evo_alg_str: str, debug: bool = False):
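
Review note: "retain the N prompts with the highest scores" also has a direct one-line expression via heapq.nlargest; a sketch, assuming score-carrying prompt objects as above (function name hypothetical). It can serve as a cross-check, since the incremental variant in this hunk seeds min_retained_score with 0 and does not recompute it after a replacement:

import heapq

def update_reference(prompts: list, N: int) -> list:
    # keep the N highest-scoring prompts from the combined population
    return heapq.nlargest(N, prompts, key=lambda p: p.score)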
@@ -178,49 +177,55 @@ def run_episode(evo_alg_str: str, debug: bool = False):
     # Line 1: Initial evaluation scores: S_0 ← {s_i = f_D(p_i) | i ∈ [1, N]}
     # the current population's scores
     population_scores = [f_D(p) for p in initial_population]
-    # S keeps track of scores
-    S = [population_scores]
-    # P keeps track of prompts and its generations
-    P = [
-        [
-            Prompt(p, score=score, gen=0)
-            for idx, (p, score) in enumerate(zip(initial_population, population_scores))
-        ]
-    ]
+    # all_prompts contains a list of Prompt objects that took part in this run at some time
+    # converting prompts to Prompt object
+    all_prompts: dict[str, Prompt] = {
+        prompt.id: prompt
+        for prompt in [
+            Prompt(p, score=score, gen=0)
+            for (p, score) in zip(initial_population, population_scores)
+        ]
+    }
+    # P keeps track of prompts in each generation
+    P = [[prompt_id for prompt_id in all_prompts.keys()]]
     # add initial prompts to family tree
     # None marks that there is no parent
     family_tree: dict[str, tuple[str, str] | None] = {
-        prompt.id: None for prompt in P[0]
+        prompt_id: None for prompt_id in P[0]
     }
     # evolution = EvolutionGA(num_evolutions=N)
     # Line 2:
     for t in trange(1, T + 1, desc="T", leave=True):
         # Line 3: Selection: select a certain number of prompts from current population as parent prompts
         # pr_1, ..., pr_k ∼ P_{t−1}
+        prompts_current_evolution = [all_prompts[prompt_id] for prompt_id in P[t - 1]]
         if evo_alg_str == "de":
             # DE needs best prompt for evolution
-            best_prompt_current_evolution = P[t - 1][
-                max(range(N), key=lambda i: S[t - 1][i])
-            ]
+            best_prompt_current_evolution = max(
+                prompts_current_evolution, key=lambda prompt: prompt.score
+            )
         new_evolutions = []
         new_evolutions_scores = []
         for i in trange(N, desc="N", leave=False):
             # for both GA and DE we start with two parent prompts
-            pr1, pr2 = selection(P[t - 1], S[t - 1])
+            pr1, pr2 = selection([all_prompts[prompt_id] for prompt_id in P[t - 1]])
             # Line 4: Evolution: generate a new prompt based on the selected parent prompts by leveraging LLM to perform evolutionary operators
             # p′_i ← Evo(pr_1, ..., pr_k)
             if evo_alg_str == "ga":
                 p_i = evolution_ga(pr1, pr2)
             elif evo_alg_str == "de":
-                p_i = evolution_de(pr1, pr2, P[t - 1][i], best_prompt_current_evolution)
+                p_i = evolution_de(
+                    pr1,
+                    pr2,
+                    prompts_current_evolution[i],
+                    best_prompt_current_evolution,
+                )
             # Line 5: Evaluation
             # s′_i ← f(p′_i, D)
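
Review note: the cleanup hinges on prompts having a stable id, so that P can hold plain id lists per generation while all_prompts owns the objects. The actual Prompt class is not visible in this diff; a plausible minimal shape, with every field name an assumption:

from dataclasses import dataclass, field
from uuid import uuid4

@dataclass
class Prompt:
    content: str
    score: float
    gen: int
    # a stable identifier lets generations reference prompts without copying them
    id: str = field(default_factory=lambda: uuid4().hex)

    def __str__(self) -> str:
        return self.content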
@@ -232,52 +237,37 @@ def run_episode(evo_alg_str: str, debug: bool = False):
             family_tree[evolved_prompt.id] = (pr1.id, pr2.id)
             new_evolutions.append(evolved_prompt)
             new_evolutions_scores.append(s_i)
+        all_prompts |= {prompt.id: prompt for prompt in new_evolutions}
         # Line 6: Update based on the evaluation scores
         # P_t ← {P_{t−1}, p′_i} and S_t ← {S_{t−1}, s′_i}
         if evo_alg_str == "ga":
             # GA keeps N best prompts from current population and evolutions
-            population, population_scores = update(
-                new_evolutions + P[t - 1], new_evolutions_scores + S[t - 1], N
-            )
+            population = update(new_evolutions + prompts_current_evolution, N)
         elif evo_alg_str == "de":
             # for DE we keep the evolved prompt if it is better than the basic prompt, and use the basic prompt otherwise
-            assert (
-                len(P[t - 1])
-                == len(S[t - 1])
-                == len(new_evolutions)
-                == len(new_evolutions_scores)
-            )
-            population, population_scores = list(
-                zip(
-                    *[
-                        (
-                            (new_prompt, new_prompt_score)
-                            if new_prompt_score > current_prompt_score
-                            else (current_prompt, current_prompt_score)
-                        )
-                        for current_prompt, current_prompt_score, new_prompt, new_prompt_score in zip(
-                            P[t - 1], S[t - 1], new_evolutions, new_evolutions_scores
-                        )
-                    ]
-                )
-            )
+            assert len(prompts_current_evolution) == len(new_evolutions)
+            population = [
+                (
+                    new_prompt
+                    if new_prompt.score > current_prompt.score
+                    else current_prompt
+                )
+                for current_prompt, new_prompt in zip(
+                    prompts_current_evolution, new_evolutions
+                )
+            ]
         # store new generation
-        P.append(population)
-        S.append(population_scores)
-        save_snapshot(family_tree, P, S, T, N)
-        save_genealogy(family_tree, P, metric_name=task.metric_name)
-        save_family_tree_visualization(family_tree, P, metric_name=task.metric_name)
+        P.append([prompt.id for prompt in population])
+        save_snapshot(all_prompts, family_tree, P, T, N)
     # Line 8: Return the best prompt, p∗, among the final population P_T:
     # p∗ ← argmax_{p ∈ P_T} f(p, D)
-    p = max(range(N), key=lambda i: S[-1][i])
-    logger.info(f"Best prompt: {P[-1][p]}")
+    p = all_prompts[max(P[-1], key=lambda prompt_id: all_prompts[prompt_id].score)]
+    logger.info(f"Best prompt: {p}")
     # We pick the prompt with the highest score on the development set and report its score on the testset.
-    test_performance = task.evaluate_test(P[-1][p])
+    test_performance = task.evaluate_test(p.content)
     logger.info(f"Best prompt on test set: {test_performance}")
......
+import re
 from abc import abstractmethod
 from collections import defaultdict
 from functools import lru_cache
-import re
 from typing import DefaultDict, Mapping, Union
 from datasets import Dataset, load_dataset
 from evaluate import load as load_metric
-from tqdm import tqdm
-from models import Llama2, OpenAI
+from llama_cpp import LlamaGrammar
+from tqdm import tqdm
+from models import Llama2, OpenAI
 from utils import log_calls, logger
 CLASSIFICATION_PROMPT = """
 Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
......
@@ -124,75 +124,13 @@ class PromptEncoder(json.JSONEncoder):
         return json.JSONEncoder.default(self, obj)

-@log_calls("Saving genealogy")
-def save_genealogy(family_tree, P, metric_name: str):
-    node_id = lambda prompt: hash(prompt)
-    dot = Digraph(comment="Genealogy")
-    for t in range(len(P)):
-        with dot.subgraph(name=f"cluster_{t}") as cluster:
-            cluster.attr(
-                label=f"Generation {t}",
-                color="lightgrey",
-            )
-            for prompt, parents in family_tree:
-                if prompt.gen != t:
-                    continue
-                cluster.node(
-                    f"{node_id(prompt)}",
-                    label=f"{fill(str(prompt))}\n{metric_name}: {prompt.score}",
-                )
-                if parents is not None:
-                    parent_1, parent_2 = parents
-                    dot.edge(f"{node_id(parent_1)}", f"{node_id(prompt)}")
-                    dot.edge(f"{node_id(parent_2)}", f"{node_id(prompt)}")
-    dot.render(run_directory / "genealogy", engine="dot")
-
-def save_family_tree_visualization(family_tree, P, metric_name: str):
-    dot = Digraph(comment="Genealogy")
-    node_id = lambda prompt: hash(prompt)
-    prompt_generations = lambda prompt: [
-        str(t) for t in range(len(P)) if prompt in P[t]
-    ]
-    node_template = dedent(
-        """
-        {prompt}
-        ∈ P{{{prompt_generations}}}
-        {metric_name}: {score}"""
-    )
-    for prompt, parents in family_tree:
-        node_text = node_template.format(
-            prompt=fill(str(prompt)),
-            prompt_generations=", ".join(prompt_generations(prompt)),
-            score=prompt.score,
-            metric_name=metric_name,
-        )
-        dot.node(
-            f"{node_id(prompt)}",
-            label=node_text,
-        )
-    for prompt, parents in family_tree:
-        if parents is not None:
-            # TODO consider more than 2 parents?
-            dot.edge(
-                f"{node_id(parents[0])}",
-                f"{node_id(prompt)}",
-            )
-            dot.edge(
-                f"{node_id(parents[1])}",
-                f"{node_id(prompt)}",
-            )
-    dot.render(run_directory / "family_tree", engine="dot")
-
-def save_snapshot(family_tree, P, S, T, N):
+def save_snapshot(all_prompts, family_tree, P, S, T, N):
     import json

     with open(run_directory / "snapshot.json", "w") as f:
         json.dump(
             {
+                "all_prompts": all_prompts,
                 "family_tree": family_tree,
                 "P": P,
                 "S": S,
......
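
Review note: for json.dump to accept the Prompt objects in all_prompts, the PromptEncoder named in the hunk header has to turn them into plain data; only its last line is visible in this diff. A minimal encoder consistent with that line, assuming the dataclass-based Prompt sketched earlier:

import json
from dataclasses import asdict, is_dataclass

class PromptEncoder(json.JSONEncoder):
    def default(self, obj):
        # dataclass instances (e.g. Prompt) serialize as plain dicts
        if is_dataclass(obj):
            return asdict(obj)
        return json.JSONEncoder.default(self, obj)

# usage: json.dump(snapshot, f, cls=PromptEncoder)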