From b77a62b7b215ef1cc641b0730e4cc93b414f16c6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Thu, 25 Apr 2024 13:28:40 +0200
Subject: [PATCH] Add parent-based early stopping monitor

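Each evaluation now records a per-sample metric trace, stores it on the Prompt
as evaluation_history, and returns it as a third value from evaluate() and
evaluate_validation():

    score, usage, history = task.evaluate_validation(prompt, parent_histories)

When a prompt has parents, their histories are handed to the evaluator so it can
stop early via the new ParentBaselineBasedStopping monitor (stop once the prompt
no longer improves on its best parent at the same step); prompts without parents
keep the existing moment-based criterion.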
---
 evolution.py    |   2 +-
 opt_types.py    |   1 +
 optimization.py |  30 ++++++++++----
 task.py         | 102 +++++++++++++++++++++++++++++++++++++++---------
 4 files changed, 108 insertions(+), 27 deletions(-)
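
Reviewer note (illustration only, not part of the applied diff): a standalone
sketch of the parent-baseline stopping rule added in task.py. The update() logic
mirrors the patch; the smaller defaults, the driver loop, and the score/parent
numbers are made up for the demo.

    from collections import deque

    class ParentBaselineBasedStopping:
        """Stop once the prompt stops improving on its best parent's metric trace."""

        def __init__(self, parent_histories, *, patience=3, start_after=2,
                     min_improvement=0.001):
            self.parent_histories = parent_histories
            self.patience = patience
            self.start_after = start_after
            self.min_improvement = min_improvement
            self.num_calls = 0
            self.improvement_memory = deque(maxlen=patience)

        def update(self, score):
            self.num_calls += 1
            if self.num_calls < self.start_after:
                return False
            # parent metric at the same step, padded with its last value if shorter
            parent_values = [
                h[self.num_calls - 1] if len(h) >= self.num_calls else h[-1]
                for h in self.parent_histories
            ]
            self.improvement_memory.append(score - max(parent_values))
            if len(self.improvement_memory) < self.patience:
                return False
            # stop when even the best recent improvement over the best parent is negligible
            return max(self.improvement_memory) < self.min_improvement

    # Made-up traces: two parents and a child that plateaus below the better parent.
    parents = [[0.50, 0.55, 0.60, 0.62, 0.63], [0.48, 0.52, 0.58, 0.60, 0.61]]
    monitor = ParentBaselineBasedStopping(parents, patience=3, start_after=2)
    for step, score in enumerate([0.51, 0.56, 0.60, 0.61, 0.62, 0.62], start=1):
        if monitor.update(score):
            print(f"early stop after {step} samples")  # -> early stop after 5 samples
            break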

diff --git a/evolution.py b/evolution.py
index 4bf51eb..d2c3e1f 100644
--- a/evolution.py
+++ b/evolution.py
@@ -169,7 +169,7 @@ class EvolutionAlgorithm(PromptOptimization):
         logger.info(f"Best prompt: {p}")
 
         # We pick the prompt with the highest score on the development set and report its score on the test set.
-        test_performance, _ = self.task.evaluate_test(p.content)
+        test_performance, _, _ = self.task.evaluate_test(p.content)
         logger.info("Best prompt on test set: %s", test_performance)
         logger.info(
             "Usage (evolution model / evaluation model / total): %s / %s / %s",
diff --git a/opt_types.py b/opt_types.py
index 3d46de8..a80c850 100644
--- a/opt_types.py
+++ b/opt_types.py
@@ -30,6 +30,7 @@ class Prompt:
     content: str
     score: float
     usage: ModelUsage
+    evaluation_history: list[float]
     meta: dict = field(default_factory=dict)
     id: str = field(default_factory=lambda: uuid4().hex)
 
diff --git a/optimization.py b/optimization.py
index 90d6e2e..7819125 100644
--- a/optimization.py
+++ b/optimization.py
@@ -58,22 +58,36 @@ class PromptOptimization:
     def reset(self):
         self._init
 
-    def evaluate_prompt(self, prompt: str):
-        return self.task.evaluate_validation(prompt)
+    def evaluate_prompt(self, prompt: str, parents: tuple[Prompt, ...] | None = None):
+        parent_histories = (
+            [parent.evaluation_history for parent in parents]
+            if parents is not None
+            else None
+        )
+        return self.task.evaluate_validation(prompt, parent_histories)
 
     def add_prompt(
-        self, prompt: str, parents: tuple[Prompt] = None, meta: dict = None
+        self,
+        prompt: str,
+        parents: tuple[Prompt, ...] | None = None,
+        meta: dict | None = None,
     ) -> Prompt:
-        score, usage = self.evaluate_prompt(prompt)
-        prompt = Prompt(content=prompt, score=score, meta=meta, usage=usage)
+        score, usage, history = self.evaluate_prompt(prompt, parents)
+        prompt_object = Prompt(
+            content=prompt,
+            score=score,
+            meta=meta if meta is not None else {},
+            usage=usage,
+            evaluation_history=history,
+        )
 
         # keep track of prompt
-        self.all_prompts[prompt.id] = prompt
-        self.family_tree[prompt.id] = (
+        self.all_prompts[prompt_object.id] = prompt_object
+        self.family_tree[prompt_object.id] = (
             tuple(p.id for p in parents) if parents is not None else None
         )
 
-        return prompt
+        return prompt_object
 
     def add_prompts(
         self,
diff --git a/task.py b/task.py
index e474b03..d062253 100644
--- a/task.py
+++ b/task.py
@@ -23,7 +23,14 @@ You are given an instruction that describes a task, paired with an input that pr
 DatasetDatum = dict
 
 
-class MomentBasedStopping:
+class EarlyStoppingMonitor:
+
+    @abstractmethod
+    def update(self, score: float) -> bool:
+        raise NotImplementedError
+
+
+class MomentBasedStopping(EarlyStoppingMonitor):
     """
     Watch the first derivative (moment) of the metric to determine when to stop.
     """
@@ -45,18 +52,60 @@ class MomentBasedStopping:
 
     def update(self, score: float) -> bool:
         # calculate the current moment (dx/dt)
+        self.num_calls += 1
+        if self.num_calls < self.start_after:
+            return False
+
         self.moment_magnitudes.append(abs(score - self.last_score))
         self.last_score = score
+        if len(self.moment_magnitudes) < self.patience:
+            return False
+
+        if mean(self.moment_magnitudes) < self.min_moment_magnitude:
+            return True
+
+        return False
+
+
+class ParentBaselineBasedStopping(EarlyStoppingMonitor):
+
+    def __init__(
+        self,
+        parent_histories: list[list[float]],
+        *,
+        patience: int = 10,
+        start_after: int = 20,
+        min_improvement: float = 0.001,
+    ):
+        self.parent_histories = parent_histories
+        self.patience = patience
+        self.start_after = start_after
+        self.min_improvement = min_improvement
+        self.num_calls = 0
+        self.improvement_memory = deque(maxlen=patience)
+
+    def update(self, score: float) -> bool:
         self.num_calls += 1
+        if self.num_calls < self.start_after:
+            return False
 
-        if (
-            self.num_calls < self.start_after
-            or len(self.moment_magnitudes) < self.patience
-        ):
+        parent_values = [  # get the metric value of the parents at the current step
+            (
+                parent_history[self.num_calls - 1]
+                if len(parent_history) >= self.num_calls
+                else parent_history[-1]  # extend with last value
+            )
+            for parent_history in self.parent_histories
+        ]
+        self.improvement_memory.append(
+            score - max(parent_values)  # compare with the best parent
+        )
+
+        if len(self.improvement_memory) < self.patience:
             return False
-        print(mean(self.moment_magnitudes))
 
-        if mean(self.moment_magnitudes) < self.min_moment_magnitude:
+        if max(self.improvement_memory) < self.min_improvement:
+            # if the highest improvement is less than the minimum improvement, we stop
             return True
 
         return False
@@ -101,16 +150,31 @@ class Task:
     def _aggregate_result(self, results: list) -> float:
         pass
 
-    def evaluate(self, prompt: str, dataset: Dataset) -> tuple[float, ModelUsage]:
-        early_stopping = MomentBasedStopping(
-            patience=len(dataset) // 20,
-            start_after=len(dataset) // 5,
-        )
+    def evaluate(
+        self,
+        prompt: str,
+        dataset: Dataset,
+        parent_histories: list[list[float]] | None = None,
+    ) -> tuple[float, ModelUsage, list[float]]:
+
+        early_stopping: EarlyStoppingMonitor
+        early_stopping_params = {
+            "patience": max(len(dataset) // 20, 5),
+            "start_after": max(len(dataset) // 5, 5),
+        }
+        if parent_histories is not None:
+            early_stopping = ParentBaselineBasedStopping(
+                parent_histories, **early_stopping_params
+            )
+        else:
+            early_stopping = MomentBasedStopping(**early_stopping_params)
+
         results: list = []
         dataset_iterator: tqdm[DatasetDatum] = tqdm(
             dataset, desc="evaluating prompt", leave=False
         )
         evaluation_usage = ModelUsage()
+        evaluation_history = []
 
         for datum in dataset_iterator:
             result, usage = self._evaluate_sample(prompt, datum)
@@ -120,18 +184,20 @@ class Task:
                 {self.metric_name: f"{current_metric*100:.1f}%"}
             )
             evaluation_usage += usage
+            evaluation_history.append(current_metric)
             if early_stopping.update(current_metric):
                 logger.info(
                     f"Early stopping after {len(results)} samples with {self.metric_name} of {current_metric*100:.1f}%"
                 )
                 break
 
-        return self._aggregate_result(results), evaluation_usage
+        return self._aggregate_result(results), evaluation_usage, evaluation_history
 
     @log_calls("Evaluating validation dataset")
-    @lru_cache(maxsize=None)
-    def evaluate_validation(self, prompt: str):
-        return self.evaluate(prompt, self.validation_dataset)
+    def evaluate_validation(
+        self, prompt: str, parent_histories: list[list[float]] | None = None
+    ):
+        return self.evaluate(prompt, self.validation_dataset, parent_histories)
 
     @log_calls("Evaluating test dataset")
     def evaluate_test(self, prompt: str):
@@ -207,7 +273,7 @@ class SentimentAnalysis(Task):
                     break
             else:
                 logger.warning(f"Invalid answer: {response}")
-                return "failed"
+                return "failed", usage
 
         classification_result = (
             "incorrect" if answer_label != datum["label"] else "correct"
@@ -333,7 +399,7 @@ class QuestionAnswering(Task):
     def _aggregate_result(self, results: list[float]) -> float:
         return sum(results) / len(results)
 
-    def evaluate(self, prompt: str, dataset: Dataset) -> tuple[float, ModelUsage]:
+    def evaluate(self, prompt: str, dataset: Dataset, parent_histories: list[list[float]] | None = None):
         def replace_symbol_for_grammar(sample: DatasetDatum):
             symbol_replacement_mapping = {
                 "\u2013": "-",
-- 
GitLab