Skip to content
Snippets Groups Projects
Commit 589cc378 authored by Grießhaber Daniel's avatar Grießhaber Daniel :squid:
Browse files

added generations a prompt appears in to the family tree

parent 07031d35
No related branches found
No related tags found
No related merge requests found
......@@ -50,7 +50,7 @@ Below is an instruction that describes a task. Write a response that paraphrases
def evaluate_prompt(prompt: str, dataset: Dataset):
sst2_labels = {"negative": 0, "positive": 1}
results = DefaultDict(int)
results: DefaultDict[str, int] = DefaultDict(int)
dataset_iterator = tqdm(dataset, desc="evaluating prompt", leave=False)
for datum in dataset_iterator:
......@@ -211,7 +211,7 @@ if __name__ == "__main__":
save_snapshot(family_tree, P, S, T, N)
save_genealogy(family_tree, P, S, T)
save_family_tree_visualization(family_tree, P, S)
save_family_tree_visualization(family_tree, P, S, T)
# Line 8: Return the best prompt, p∗, among the final population PT :
# p∗ ← argmaxp∈PT f(p, D)
p = max(range(N), key=lambda i: S[T][i])
......
......@@ -126,10 +126,11 @@ def save_genealogy(family_tree, P, S, T):
dot.render(run_directory / "genealogy", engine="dot")
def save_family_tree_visualization(family_tree, P, S):
def save_family_tree_visualization(family_tree, P, S, T):
dot = Digraph(comment="Genealogy")
node_id = lambda prompt: shake_256(prompt.encode("utf-8")).hexdigest(5)
flatten = lambda l: [item for sublist in l for item in sublist]
prompt_generations = lambda prompt: [str(t) for t in range(T) if prompt in P[t]]
all_prompts = set(P[0] + list(family_tree.keys()))
scores = {
prompt: score
......@@ -138,10 +139,21 @@ def save_family_tree_visualization(family_tree, P, S):
flatten(S),
)
}
node_template = dedent(
"""
{prompt}
∈ P{{{prompt_generations}}}
Accuracy: {score}"""
)
for prompt in all_prompts:
node_text = node_template.format(
prompt=fill(prompt),
prompt_generations=", ".join(prompt_generations(prompt)),
score=scores[prompt],
)
dot.node(
f"{node_id(prompt)}",
label=f"{fill(prompt)}\n Accuracy: {scores[prompt]}",
label=node_text,
)
for prompt, parents in family_tree.items():
dot.edge(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment