From 28441f6d63f5478449d1814f34751dfad34136c5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Thu, 25 Apr 2024 11:40:24 +0200
Subject: [PATCH] implement moment-based early stopping

---
 task.py | 52 +++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 51 insertions(+), 1 deletion(-)

diff --git a/task.py b/task.py
index adfaa1f..e474b03 100644
--- a/task.py
+++ b/task.py
@@ -2,11 +2,13 @@ import re
 from abc import abstractmethod
 from argparse import Namespace
+from collections import deque
 from functools import lru_cache
+from statistics import mean
 from typing import Union
 
 from datasets import Dataset, load_dataset
 from evaluate import load as load_metric
 from llama_cpp import LlamaGrammar
 from tqdm import tqdm
 
 from cli import argument_parser
@@ -21,6 +23,45 @@ You are given an instruction that describes a task, paired with an input that pr
 DatasetDatum = dict
 
 
+class MomentBasedStopping:
+    """
+    Watch the first derivative (moment) of the metric to determine when to stop.
+    """
+
+    def __init__(
+        self,
+        *,
+        patience: int = 10,
+        start_after: int = 20,
+        min_moment_magnitude: float = 0.001,
+    ):
+        self.patience = patience
+        self.start_after = start_after
+        self.min_moment_magnitude = min_moment_magnitude
+
+        self.moment_magnitudes = deque(maxlen=patience)
+        self.last_score = 0.0
+        self.num_calls = 0
+
+    def update(self, score: float) -> bool:
+        # calculate the current moment (dx/dt)
+        self.moment_magnitudes.append(abs(score - self.last_score))
+        self.last_score = score
+        self.num_calls += 1
+
+        if (
+            self.num_calls < self.start_after
+            or len(self.moment_magnitudes) < self.patience
+        ):
+            return False
+        # stop once the mean recent change drops below the threshold
+
+        if mean(self.moment_magnitudes) < self.min_moment_magnitude:
+            return True
+
+        return False
+
+
 class Task:
     shorthand: str
     validation_dataset: Dataset
@@ -61,6 +102,10 @@ class Task:
         pass
 
     def evaluate(self, prompt: str, dataset: Dataset) -> tuple[float, ModelUsage]:
+        early_stopping = MomentBasedStopping(
+            patience=max(1, len(dataset) // 20),
+            start_after=max(1, len(dataset) // 5),
+        )
         results: list = []
         dataset_iterator: tqdm[DatasetDatum] = tqdm(
             dataset, desc="evaluating prompt", leave=False
@@ -75,6 +120,11 @@ class Task:
                 {self.metric_name: f"{current_metric*100:.1f}%"}
             )
             evaluation_usage += usage
+            if early_stopping.update(current_metric):
+                logger.info(
+                    f"Early stopping after {len(results)} samples with {self.metric_name} of {current_metric*100:.1f}%"
+                )
+                break
 
         return self._aggregate_result(results), evaluation_usage
 
-- 
GitLab