From 7fcaaca757dd2338ee041d047464027db6fada3f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Wed, 22 May 2024 09:54:17 +0200
Subject: [PATCH] save snapshot during evolution process

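Write the snapshot to disk continuously while the optimization runs
instead of once at the end of the run:

* save_snapshot() becomes an instance method of PromptOptimization and
  is called after every added prompt and after every finished
  generation, so intermediate results are available on disk while a
  run is still in progress.
* The run state (P, family_tree, all_prompts, the usage counters, the
  run directory and the iterations progress bar) is initialized in
  init_run() and kept on the optimizer instance.
* The CLI options are passed to the optimizer constructor as
  run_options; run() no longer takes an add_snapshot_dict argument.

The layout of snapshot.json is unchanged. Call sites construct the
optimizer with run_options and call run() without the extra dict;
an illustrative sketch mirroring the updated main.py (the concrete
optimizer subclass and the population size are assumptions, not part
of this patch):

    # hypothetical call site: the subclass and population size are
    # illustrative, the remaining arguments follow the updated main.py
    optimizer = GeneticAlgorithm(
        population_size=10,
        task=task,
        evolution_model=evolution_model,
        evaluation_model=evaluation_model,
        run_options=options.__dict__,
    )
    optimizer.run(10, debug=debug)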
---
 api/optimization.py |   3 +-
 evolution.py        |  71 +++++++-----------------
 frontend            |   2 +-
 main.py             |   3 +-
 optimization.py     | 132 ++++++++++++++++++++++++--------------------
 5 files changed, 98 insertions(+), 113 deletions(-)

diff --git a/api/optimization.py b/api/optimization.py
index de5908d..2241ebc 100644
--- a/api/optimization.py
+++ b/api/optimization.py
@@ -53,6 +53,7 @@ class MultiProcessOptimizer:
             task=task,
             evolution_model=evolution_model,
             evaluation_model=evaluation_model,
+            run_options=options.__dict__,
         )
 
     def __exit__(self, exc_type, exc_value, exc_tb):
@@ -89,7 +90,7 @@
 
     def run_optimization(self, num_iterations: int) -> str:
         self._running = True
-        self.optimizer.run(num_iterations, debug=self.debug, add_snapshot_dict={})
+        self.optimizer.run(num_iterations, debug=self.debug)
         self._running = False
 
     def get_progress(self):
diff --git a/evolution.py b/evolution.py
index 1dc8f38..dc133cc 100644
--- a/evolution.py
+++ b/evolution.py
@@ -1,4 +1,5 @@
 from abc import abstractmethod
+from typing import Any
 
 from numpy.random import choice
 from tqdm import trange
@@ -6,9 +7,9 @@ from tqdm import trange
 from cli import argument_parser
 from models import LLMModel
 from opt_types import ModelUsage, Prompt
-from optimization import PromptOptimization, save_snapshot
+from optimization import PromptOptimization
 from task import Task
-from utils import initialize_run_directory, log_calls, logger
+from utils import log_calls, logger
 
 SYSTEM_MESSAGE = (
     "Please follow the instruction step-by-step to generate a better prompt."
@@ -46,11 +47,13 @@ class EvolutionAlgorithm(PromptOptimization):
         task: Task,
         evolution_model: LLMModel,
         evaluation_model: LLMModel,
+        run_options: dict[str, Any] | None = None,
     ) -> None:
         super().__init__(
             task=task,
             evolution_model=evolution_model,
             evaluation_model=evaluation_model,
+            run_options=run_options,
         )
 
         self.population_size = population_size
@@ -88,46 +91,27 @@ class EvolutionAlgorithm(PromptOptimization):
     def update(self, *args, **kwargs):
         pass
 
-    def run(
-        self, num_iterations: int, add_snapshot_dict: dict, debug: bool = False
-    ) -> None:
+    def run(self, num_iterations: int, debug: bool = False) -> tuple[ModelUsage, ModelUsage]:
         # debug mode for quick run
         if debug:
             self.population_size = 3
             num_iterations = 2
 
-        # model usage for evaluation of prompts
-        total_evaluation_usage = ModelUsage()
-        # model usage for evolution of prompts
-        total_evolution_usage = ModelUsage()
-
-        run_directory = initialize_run_directory(self.evolution_model)
-
-        initial_prompts, evolution_usage, evaluation_usage = self.init_run(
-            self.population_size
-        )
-        total_evaluation_usage += evaluation_usage
-        total_evolution_usage += evolution_usage
+        self.init_run(self.population_size, num_iterations)
 
         # Algorithm 1 Discrete prompt optimization: EVOPROMPT
 
-        # P keeps track of prompts in each generation
-        P = [initial_prompts]
-
         # Line 2:
-        self.iterations_pbar = trange(
-            1, num_iterations + 1, desc="iterations", leave=True
-        )
         for t in self.iterations_pbar:
             # Line 3: Selection: select a certain number of prompts from current population as parent prompts
             # pr1,...,prk ∼ Pt−1
-            prompts_current_evolution = P[t - 1]
+            prompts_current_evolution = self.P[t - 1]
 
             new_evolutions = []
 
             for i in trange(self.population_size, desc="updates", leave=False):
                 # for both GA and DE we start with two parent prompts
-                pr1, pr2 = self.select(P[t - 1])
+                pr1, pr2 = self.select(self.P[t - 1])
 
                 # Line 4: Evolution: generate a new prompt based on the selected parent prompts by leveraging LLM to perform evolutionary operators
                 # p′i ←Evo(pr1,...,prk)
@@ -137,38 +121,25 @@ class EvolutionAlgorithm(PromptOptimization):
                     prompts_current_evolution=prompts_current_evolution,
                     current_iteration=i,
                 )
-                total_evolution_usage += evolution_usage
+                self.total_evolution_usage += evolution_usage
 
                 evolved_prompt = self.add_prompt(p_i, (pr1, pr2), {"gen": t})
-                evaluation_usage += evolved_prompt.usage
+                self.total_evaluation_usage += evolved_prompt.usage
 
                 new_evolutions.append(evolved_prompt)
+                self.save_snapshot()
             # Line 6: Update based on the evaluation scores
             # Pt ← {Pt−1, p′i} and St ← {St−1, s′i}
             new_population = self.update(new_evolutions, prompts_current_evolution)
 
             # store new generation
-            P.append(new_population)
-
-        # TODO move to super class
-        save_snapshot(
-            run_directory,
-            self.all_prompts,
-            self.family_tree,
-            [[prompt.id for prompt in population] for population in P],
-            num_iterations,
-            self.population_size,
-            self.task,
-            self.evolution_model,
-            # model usage for evaluating prompts
-            total_evaluation_usage,
-            # model usage for evolution of prompts
-            total_evolution_usage,
-            add_snapshot_dict,
-        )
+            self.P.append(new_population)
+            self.save_snapshot()
+
+        self.save_snapshot()
         # Line 8: Return the best prompt, p∗, among the final population PT :
         # p∗ ← argmaxp∈PT f(p, D)
-        p = max(P[-1], key=lambda prompt: self.all_prompts[prompt.id].score)
+        p = max(self.P[-1], key=lambda prompt: self.all_prompts[prompt.id].score)
         logger.info(f"Best prompt: {p}")
 
         # We pick the prompt with the highest score on the development set and report its score on the testset.
@@ -176,12 +147,12 @@ class EvolutionAlgorithm(PromptOptimization):
         logger.info("Best prompt on test set: %s", test_performance)
         logger.info(
             "Usage (evolution model / evaluation model / total): %s / %s / %s",
-            total_evolution_usage,
-            total_evaluation_usage,
-            total_evolution_usage + total_evaluation_usage,
+            self.total_evolution_usage,
+            self.total_evaluation_usage,
+            self.total_evolution_usage + self.total_evaluation_usage,
         )
 
-        return total_evolution_usage, total_evaluation_usage
+        return self.total_evolution_usage, self.total_evaluation_usage
 
 
 class GeneticAlgorithm(EvolutionAlgorithm):
diff --git a/frontend b/frontend
index d430de1..6e7b5ed 160000
--- a/frontend
+++ b/frontend
@@ -1 +1 @@
-Subproject commit d430de1597342eedf0cede1873507a3ffaa28dbb
+Subproject commit 6e7b5edc0e5b34fe100f8b2f46c0117d861c90ee
diff --git a/main.py b/main.py
index 4ca5c51..2a9a7dd 100644
--- a/main.py
+++ b/main.py
@@ -74,5 +74,6 @@ if __name__ == "__main__":
         task=task,
         evolution_model=evolution_model,
         evaluation_model=evaluation_model,
+        run_options=options.__dict__,
     )
-    optimizer.run(10, debug=debug, add_snapshot_dict=options.__dict__)
+    optimizer.run(10, debug=debug)
diff --git a/optimization.py b/optimization.py
index f5592fc..608a316 100644
--- a/optimization.py
+++ b/optimization.py
@@ -1,12 +1,15 @@
 import json
+from abc import abstractmethod
 from itertools import zip_longest
 from pathlib import Path
 from typing import Any
 
+from tqdm import trange
+
 from models import Llama2, LLMModel, OpenAI
 from opt_types import ModelUsage, OptTypeEncoder, Prompt
 from task import Task
-from utils import log_calls
+from utils import initialize_run_directory, log_calls
 
 PARAPHRASE_PROMPT = """You are given an instruction that describes a task. Write a response that paraphrases the instruction. Only output the paraphrased instruction bracketed in <prompt> and </prompt>."""
 
@@ -40,22 +43,28 @@ def paraphrase_prompts(
 
 
 class PromptOptimization:
+    total_evaluation_usage: ModelUsage
+    total_evolution_usage: ModelUsage
+    run_directory: Path
+    # P contains the list of prompts at each generation
+    P: list[list[Prompt]]
+    # family_tree maps each prompt id to the ids of its parent prompts (None for initial prompts)
+    family_tree: dict[str, tuple[str, ...] | None]
+    # all_prompts maps prompt ids to all Prompt objects that took part in the optimization
+    all_prompts: dict[str, Prompt]
+
     def __init__(
-        self, *, task: Task, evolution_model: LLMModel, evaluation_model: LLMModel
+        self,
+        *,
+        task: Task,
+        evolution_model: LLMModel,
+        evaluation_model: LLMModel,
+        run_options: dict[str, Any] | None = None
     ) -> None:
         self.task = task
         self.evolution_model = evolution_model
         self.evaluation_model = evaluation_model
-        self._init()
-
-    def _init(self):
-        # family_tree contains the relation of prompts to its parents
-        self.family_tree: dict[str, tuple[str, ...] | None] = {}
-        # all_prompts contains a list of Prompt objects that took part in the optimization
-        self.all_prompts: dict[str, Prompt] = {}
-
-    def reset(self):
-        self._init
+        self.run_options = run_options if run_options is not None else {}
 
     def evaluate_prompt(self, prompt: str, parents: tuple[Prompt] | None = None):
         parent_histories = (
@@ -85,6 +94,7 @@ class PromptOptimization:
         self.family_tree[prompt_object.id] = (
             tuple(p.id for p in parents) if parents is not None else None
         )
+        self.save_snapshot()
 
         return prompt_object
 
@@ -105,67 +115,69 @@ class PromptOptimization:
     def get_prompts(self, prompt_ids: list[str]):
         return [self.get_prompt(p_id) for p_id in prompt_ids]
 
-    def init_run(
-        self, num_initial_prompts: int
-    ) -> tuple[list[Prompt], ModelUsage, ModelUsage]:
-        # - Initial prompts P0 = {p1, p2, . . . , pN }
+    @abstractmethod
+    def save_snapshot(self): ...
+
+    def init_run(self, num_initial_prompts: int, num_iterations: int):
+        # family_tree maps each prompt id to the ids of its parent prompts (None for initial prompts)
+        self.family_tree: dict[str, tuple[str, ...] | None] = {}
+        # all_prompts maps prompt ids to all Prompt objects that took part in the optimization
+        self.all_prompts: dict[str, Prompt] = {}
+        self.P = []
+        self.total_evaluation_usage = ModelUsage()
+        self.total_evolution_usage = ModelUsage()
+        self.iterations_pbar = trange(
+            1, num_iterations + 1, desc="iterations", leave=True
+        )
+
+        self.run_directory = initialize_run_directory(self.evolution_model)
+        self.save_snapshot()
+
         paraphrases, paraphrase_usage = paraphrase_prompts(
             self.evolution_model, self.task.base_prompt, n=num_initial_prompts - 1
         )
+        self.total_evolution_usage += paraphrase_usage
 
         # the initial prompts
         initial_prompts = [self.task.base_prompt] + paraphrases
         initial_prompts = self.add_prompts(
             initial_prompts, metas=[{"gen": 0} for _ in initial_prompts]
         )
+        # - Initial prompts P0 = {p1, p2, . . . , pN }
+        self.P.append(initial_prompts)
 
         # accumulate usage
-        evaluation_usage = ModelUsage()
         for prompt in initial_prompts:
-            evaluation_usage += prompt.usage
-
-        return initial_prompts, paraphrase_usage, evaluation_usage
-
-
-# TODO turn snapshots methods into instance methods of optimizer
-def save_snapshot(
-    run_directory: Path,
-    all_prompts: list[Prompt],
-    family_tree: dict[str, tuple[str, str] | None],
-    P: list[list[str]],
-    T: int,
-    N: int,
-    task,
-    model: Llama2 | OpenAI,
-    evaluation_usage: ModelUsage,
-    evolution_usage: ModelUsage,
-    run_options: dict[str, Any],
-):
-
-    with open(run_directory / "snapshot.json", "w") as f:
-        json.dump(
-            {
-                "all_prompts": all_prompts,
-                "family_tree": family_tree,
-                "P": P,
-                "T": T,
-                "N": N,
-                "task": {
-                    "name": task.__class__.__name__,
-                    "validation_dataset": task.validation_dataset.info.dataset_name,
-                    "test_dataset": task.test_dataset.info.dataset_name,
-                    "metric": task.metric_name,
-                    "use_grammar": task.use_grammar,
+            self.total_evaluation_usage += prompt.usage
+        self.save_snapshot()
+
+    def save_snapshot(self):
+        with open(self.run_directory / "snapshot.json", "w") as f:
+            json.dump(
+                {
+                    "all_prompts": self.all_prompts,
+                    "family_tree": self.family_tree,
+                    "P": [
+                        [prompt.id for prompt in population] for population in self.P
+                    ],
+                    "T": self.iterations_pbar.total,  # total number of iterations
+                    "N": len(self.P[0]) if self.P else 0,  # population size
+                    "task": {
+                        "name": self.task.__class__.__name__,
+                        "validation_dataset": self.task.validation_dataset.info.dataset_name,
+                        "test_dataset": self.task.test_dataset.info.dataset_name,
+                        "metric": self.task.metric_name,
+                        "use_grammar": self.task.use_grammar,
+                    },
+                    "model": {"name": self.evolution_model.__class__.__name__},
+                    "evaluation_usage": self.total_evaluation_usage,
+                    "evolution_usage": self.total_evolution_usage,
+                    "run_options": self.run_options,
                 },
-                "model": {"name": model.__class__.__name__},
-                "evaluation_usage": evaluation_usage,
-                "evolution_usage": evolution_usage,
-                "run_options": run_options,
-            },
-            f,
-            indent=4,
-            cls=OptTypeEncoder,
-        )
+                f,
+                indent=4,
+                cls=OptTypeEncoder,
+            )
 
 
 def load_snapshot(path: Path):
-- 
GitLab