From efd17f2f7fcd4c537019d44ff779148544a99997 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Thu, 7 Mar 2024 10:42:10 +0100
Subject: [PATCH] clean up data model to avoid duplication and improve
 interpretability

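Prompts and their development-set scores were previously carried in two
parallel lists (P and S) that had to be kept in sync by hand. Scores now
live on the Prompt objects themselves: all_prompts maps a prompt id to
its Prompt, and P only records the ids that make up each generation. The
Prompt class itself is not part of this diff; a minimal sketch of the
shape the new code assumes (field names taken from the call sites,
everything else is illustrative):

    from dataclasses import dataclass, field
    from uuid import uuid4

    @dataclass
    class Prompt:
        content: str
        score: float
        gen: int
        id: str = field(default_factory=lambda: uuid4().hex)

        def __str__(self) -> str:
            return self.content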
---
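
Reviewer notes (below the --- fold, not part of the commit message):

selection() keeps the roulette-wheel sampling described in its comments,
p_i = s_i / sum_j s_j; it merely reads scores off the Prompt objects
instead of taking a parallel score list. The intended behaviour boils
down to the sketch below (random.choices draws with replacement; whether
the real selection() forces two distinct parents is not visible in this
patch):

    import random

    def roulette_wheel(prompts, k=2):
        # p_i = s_i / sum_j s_j; uniform fallback when every score is 0
        total = sum(p.score for p in prompts)
        weights = [p.score / total for p in prompts] if total > 0 else None
        return random.choices(prompts, weights=weights, k=k)
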
 main.py  | 108 ++++++++++++++++++++++++++------------------------------
 task.py  |   7 ++--
 utils.py |  67 ++---------------------------------
 3 files changed, 54 insertions(+), 128 deletions(-)

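The rewritten update() retains the N highest-scoring prompts from the
combined population, tracking a running minimum instead of a parallel
score list. Behaviour-wise it should match this reference implementation
(heapq.nlargest is shown purely for comparison; the patch does not use
it):

    import heapq

    def update_top_n(prompts, N):
        # retain the N prompts with the highest scores
        return heapq.nlargest(N, prompts, key=lambda p: p.score)
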
diff --git a/main.py b/main.py
index b16e646..b65ab4a 100644
--- a/main.py
+++ b/main.py
@@ -60,7 +60,7 @@ def paraphrase_prompts(prompt: str, n: int):
 
 
 @log_calls("Performing selection")
-def selection(prompts, scores):
+def selection(prompts: list[Prompt]):
     # In GA, two parent solutions are normally selected based on the roulette wheel
     # selection method according to the fitness value (Lipowski & Lipowska, 2012).
     # Similar to this, we utilize the roulette wheel selection method to select
@@ -69,6 +69,7 @@ def selection(prompts, scores):
     # development set of the i-th prompt in the population, which contains a total
     # of N prompts. The probability of selecting the i-th prompt as a parent can be expressed as
     # pi = si / Σj=1->N sj.
+    scores = [prompt.score for prompt in prompts]
     if sum(scores) == 0:
         # sum of scores is 0 ==> each score is 0, draw with equal probability
         selection_probabilities = len(scores) * [1 / len(scores)]
@@ -134,7 +135,7 @@ def evolution_de(prompt1: str, prompt2: str, basic_prompt: str, best_prompt: str
 
 
 @log_calls("Updating prompts")
-def update(prompts: list[str], scores: list[float], N: int):
+def update(prompts: list[Prompt], N: int):
     # EVOPROMPT iteratively generates new candidate prompts and assesses each prompt
     # using a development set, denoted as D, to obtain a score that quantifies the
     # quality of the prompt. We consider a straightforward selection strategy.
@@ -142,18 +143,17 @@
     # which are combined with the current population of N prompts.
     # The updated population is then selected by retaining the N prompts with the highest scores.
     retained_prompts = []
-    retained_scores = []
-
-    for prompt, score in zip(prompts, scores):
+    min_retained_score = float("inf")
+    for prompt in prompts:
         if len(retained_prompts) < N:
             retained_prompts.append(prompt)
-            retained_scores.append(score)
-        elif score > min(retained_scores):
-            min_index = retained_scores.index(min(retained_scores))
-            retained_prompts[min_index] = prompt
-            retained_scores[min_index] = score
+            min_retained_score = min(min_retained_score, prompt.score)
+        elif prompt.score > min_retained_score:
+            retained_prompts.sort(key=lambda p: p.score)  # worst prompt first
+            retained_prompts[0] = prompt
+            min_retained_score = min(p.score for p in retained_prompts)
 
-    return retained_prompts, retained_scores
+    return retained_prompts
 
 
 def run_episode(evo_alg_str: str, debug: bool = False):
@@ -178,49 +178,54 @@
     # Line 1: Initial evaluation scores: S0 ← {si = fD (pi )|i ∈ [1, N ]}
     # the current population's scores
     population_scores = [f_D(p) for p in initial_population]
-    # S keeps track of scores
-    S = [population_scores]
 
-    # P keeps track of prompts and its generations
+    # all_prompts maps prompt id -> Prompt for every prompt that took part in this run
     # converting prompts to Prompt object
-    P = [
-        [
+    all_prompts: dict[str, Prompt] = {
+        prompt.id: prompt
+        for prompt in [
             Prompt(p, score=score, gen=0)
-            for idx, (p, score) in enumerate(zip(initial_population, population_scores))
+            for (p, score) in zip(initial_population, population_scores)
         ]
-    ]
+    }
+
+    # P keeps track of the prompt ids that make up each generation
+    P = [list(all_prompts.keys())]
 
     # add initial prompts to family tree
     # None marks that there is no parent
     family_tree: dict[str, tuple[str, str] | None] = {
-        prompt.id: None for prompt in P[0]
+        prompt_id: None for prompt_id in P[0]
     }
 
-    # evolution = EvolutionGA(num_evolutions=N)
-
     # Line 2:
     for t in trange(1, T + 1, desc="T", leave=True):
         # Line 3: Selection: select a certain number of prompts from current population as parent prompts
         # pr1,...,prk ∼ Pt−1
+        prompts_current_evolution = [all_prompts[prompt_id] for prompt_id in P[t - 1]]
         if evo_alg_str == "de":
             # DE needs best prompt for evolution
-            best_prompt_current_evolution = P[t - 1][
-                max(range(N), key=lambda i: S[t - 1][i])
-            ]
+            best_prompt_current_evolution = max(
+                prompts_current_evolution, key=lambda prompt: prompt.score
+            )
 
         new_evolutions = []
-        new_evolutions_scores = []
 
         for i in trange(N, desc="N", leave=False):
             # for both GA and DE we start with two parent prompts
-            pr1, pr2 = selection(P[t - 1], S[t - 1])
+            pr1, pr2 = selection(prompts_current_evolution)
 
             # Line 4: Evolution: generate a new prompt based on the selected parent prompts by leveraging LLM to perform evolutionary operators
             # p′i ←Evo(pr1,...,prk)
             if evo_alg_str == "ga":
                 p_i = evolution_ga(pr1, pr2)
             elif evo_alg_str == "de":
-                p_i = evolution_de(pr1, pr2, P[t - 1][i], best_prompt_current_evolution)
+                p_i = evolution_de(
+                    pr1,
+                    pr2,
+                    prompts_current_evolution[i],
+                    best_prompt_current_evolution,
+                )
 
             # Line 5: Evaluation
             # s′_i ← f(p′i,D)
@@ -232,52 +237,37 @@ def run_episode(evo_alg_str: str, debug: bool = False):
             family_tree[evolved_prompt.id] = (pr1.id, pr2.id)
 
             new_evolutions.append(evolved_prompt)
-            new_evolutions_scores.append(s_i)
-
+        all_prompts |= {prompt.id: prompt for prompt in new_evolutions}
         # Line 6: Update based on the evaluation scores
         # Pt ← {Pt−1, p′i} and St ← {St−1, s′i}
         if evo_alg_str == "ga":
             # GA keeps N best prompts from current population and evolutions
-            population, population_scores = update(
-                new_evolutions + P[t - 1], new_evolutions_scores + S[t - 1], N
-            )
+            population = update(new_evolutions + prompts_current_evolution, N)
         elif evo_alg_str == "de":
             # for DE we keep the evolved prompt if it is better than the basic prompt, and use the basic prompt otherwise
-            assert (
-                len(P[t - 1])
-                == len(S[t - 1])
-                == len(new_evolutions)
-                == len(new_evolutions_scores)
-            )
-            population, population_scores = list(
-                zip(
-                    *[
-                        (
-                            (new_prompt, new_prompt_score)
-                            if new_prompt_score > current_prompt_score
-                            else (current_prompt, current_prompt_score)
-                        )
-                        for current_prompt, current_prompt_score, new_prompt, new_prompt_score in zip(
-                            P[t - 1], S[t - 1], new_evolutions, new_evolutions_scores
-                        )
-                    ]
+            assert len(prompts_current_evolution) == len(new_evolutions)
+            population = [
+                (
+                    new_prompt
+                    if new_prompt.score > current_prompt.score
+                    else current_prompt
                 )
-            )
+                for current_prompt, new_prompt in zip(
+                    prompts_current_evolution, new_evolutions
+                )
+            ]
 
         # store new generation
-        P.append(population)
-        S.append(population_scores)
+        P.append([prompt.id for prompt in population])
 
-    save_snapshot(family_tree, P, S, T, N)
-    save_genealogy(family_tree, P, metric_name=task.metric_name)
-    save_family_tree_visualization(family_tree, P, metric_name=task.metric_name)
+    save_snapshot(all_prompts, family_tree, P, T, N)
     # Line 8: Return the best prompt, p∗, among the final population PT :
     # p∗ ← argmaxp∈PT f(p, D)
-    p = max(range(N), key=lambda i: S[-1][i])
-    logger.info(f"Best prompt: {P[-1][p]}")
+    p = all_prompts[max(P[-1], key=lambda prompt_id: all_prompts[prompt_id].score)]
+    logger.info(f"Best prompt: {p}")
 
     # We pick the prompt with the highest score on the development set and report its score on the testset.
-    test_performance = task.evaluate_test(P[-1][p])
+    test_performance = task.evaluate_test(p.content)
     logger.info(f"Best prompt on test set: {test_performance}")
 
 
diff --git a/task.py b/task.py
index c1ba91a..1a18848 100644
--- a/task.py
+++ b/task.py
@@ -1,18 +1,17 @@
+import re
 from abc import abstractmethod
 from collections import defaultdict
 from functools import lru_cache
-import re
 from typing import DefaultDict, Mapping, Union
 
 from datasets import Dataset, load_dataset
 from evaluate import load as load_metric
-from tqdm import tqdm
-from models import Llama2, OpenAI
 from llama_cpp import LlamaGrammar
+from tqdm import tqdm
 
+from models import Llama2, OpenAI
 from utils import log_calls, logger
 
-
 CLASSIFICATION_PROMPT = """
 Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 
diff --git a/utils.py b/utils.py
index 7548aa1..17f97fb 100644
--- a/utils.py
+++ b/utils.py
@@ -124,75 +124,12 @@
         return json.JSONEncoder.default(self, obj)
 
 
-@log_calls("Saving genealogy")
-def save_genealogy(family_tree, P, metric_name: str):
-    node_id = lambda prompt: hash(prompt)
-    dot = Digraph(comment="Genealogy")
-    for t in range(len(P)):
-        with dot.subgraph(name=f"cluster_{t}") as cluster:
-            cluster.attr(
-                label=f"Generation {t}",
-                color="lightgrey",
-            )
-            for prompt, parents in family_tree:
-                if prompt.gen != t:
-                    continue
-                cluster.node(
-                    f"{node_id(prompt)}",
-                    label=f"{fill(str(prompt))}\n{metric_name}: {prompt.score}",
-                )
-                if parents is not None:
-                    parent_1, parent_2 = parents
-                    dot.edge(f"{node_id(parent_1)}", f"{node_id(prompt)}")
-                    dot.edge(f"{node_id(parent_2)}", f"{node_id(prompt)}")
-
-    dot.render(run_directory / "genealogy", engine="dot")
-
-
-def save_family_tree_visualization(family_tree, P, metric_name: str):
-    dot = Digraph(comment="Genealogy")
-    node_id = lambda prompt: hash(prompt)
-    prompt_generations = lambda prompt: [
-        str(t) for t in range(len(P)) if prompt in P[t]
-    ]
-    node_template = dedent(
-        """
-        {prompt}
-        ∈ P{{{prompt_generations}}}
-        {metric_name}: {score}"""
-    )
-    for prompt, parents in family_tree:
-        node_text = node_template.format(
-            prompt=fill(str(prompt)),
-            prompt_generations=", ".join(prompt_generations(prompt)),
-            score=prompt.score,
-            metric_name=metric_name,
-        )
-        dot.node(
-            f"{node_id(prompt)}",
-            label=node_text,
-        )
-    for prompt, parents in family_tree:
-        if parents is not None:
-            # TODO consider more than 2 parents?
-            dot.edge(
-                f"{node_id(parents[0])}",
-                f"{node_id(prompt)}",
-            )
-            dot.edge(
-                f"{node_id(parents[1])}",
-                f"{node_id(prompt)}",
-            )
-
-    dot.render(run_directory / "family_tree", engine="dot")
-
-
-def save_snapshot(family_tree, P, S, T, N):
+def save_snapshot(all_prompts, family_tree, P, T, N):
     import json
 
     with open(run_directory / "snapshot.json", "w") as f:
         json.dump(
             {
+                "all_prompts": all_prompts,
                 "family_tree": family_tree,
                 "P": P,
                 "S": S,
-- 
GitLab
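
Post-script (after the mail signature, not applied by git am): with S
gone, snapshot.json has a single source of truth for scores. Assuming
PromptEncoder flattens a Prompt into its plain fields (the encoder is
unchanged by this diff), a snapshot would look roughly like:

    snapshot = {
        "all_prompts": {"3f2a": {"content": "...", "score": 0.71, "gen": 0}},
        "family_tree": {"3f2a": None},  # or (parent_1_id, parent_2_id)
        "P": [["3f2a"]],                # prompt ids per generation
        "T": 10,
        "N": 8,
    }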