diff --git a/main.py b/main.py index 39b8e04df4f878cbb3b8bfaba30aa305d637189e..c7af5478b40d036fdba5f1a7ab2d46689b4f923f 100644 --- a/main.py +++ b/main.py @@ -50,7 +50,7 @@ Below is an instruction that describes a task. Write a response that paraphrases def evaluate_prompt(prompt: str, dataset: Dataset): sst2_labels = {"negative": 0, "positive": 1} - results = DefaultDict(int) + results: DefaultDict[str, int] = DefaultDict(int) dataset_iterator = tqdm(dataset, desc="evaluating prompt", leave=False) for datum in dataset_iterator: @@ -211,7 +211,7 @@ if __name__ == "__main__": save_snapshot(family_tree, P, S, T, N) save_genealogy(family_tree, P, S, T) - save_family_tree_visualization(family_tree, P, S) + save_family_tree_visualization(family_tree, P, S, T) # Line 8: Return the best prompt, p∗, among the final population PT : # p∗ ← argmaxp∈PT f(p, D) p = max(range(N), key=lambda i: S[T][i]) diff --git a/utils.py b/utils.py index 24af6a62468f0afbde5faad2920b494c056d4f61..fe4ef1314adfb081a58d1774a18f070a5d1d3b22 100644 --- a/utils.py +++ b/utils.py @@ -126,10 +126,11 @@ def save_genealogy(family_tree, P, S, T): dot.render(run_directory / "genealogy", engine="dot") -def save_family_tree_visualization(family_tree, P, S): +def save_family_tree_visualization(family_tree, P, S, T): dot = Digraph(comment="Genealogy") node_id = lambda prompt: shake_256(prompt.encode("utf-8")).hexdigest(5) flatten = lambda l: [item for sublist in l for item in sublist] + prompt_generations = lambda prompt: [str(t) for t in range(T) if prompt in P[t]] all_prompts = set(P[0] + list(family_tree.keys())) scores = { prompt: score @@ -138,10 +139,21 @@ def save_family_tree_visualization(family_tree, P, S): flatten(S), ) } + node_template = dedent( + """ + {prompt} + ‚àà P{{{prompt_generations}}} + Accuracy: {score}""" + ) for prompt in all_prompts: + node_text = node_template.format( + prompt=fill(prompt), + prompt_generations=", ".join(prompt_generations(prompt)), + score=scores[prompt], 
) dot.node( f"{node_id(prompt)}", - label=f"{fill(prompt)}\n Accuracy: {scores[prompt]}", + label=node_text, ) for prompt, parents in family_tree.items(): dot.edge(