From 28441f6d63f5478449d1814f34751dfad34136c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de> Date: Thu, 25 Apr 2024 11:40:24 +0200 Subject: [PATCH] implement moment-based early stopping --- task.py | 52 +++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/task.py b/task.py index adfaa1f..e474b03 100644 --- a/task.py +++ b/task.py @@ -2,11 +2,13 @@ import re from abc import abstractmethod from argparse import Namespace from functools import lru_cache +from statistics import mean from typing import Union from datasets import Dataset, load_dataset from evaluate import load as load_metric -from llama_cpp import LlamaGrammar +from llama_cpp import LlamaGrammar, deque +from torch.utils import data from tqdm import tqdm from cli import argument_parser @@ -21,6 +23,45 @@ You are given an instruction that describes a task, paired with an input that pr DatasetDatum = dict +class MomentBasedStopping: + """ + Watch the first derivative (moment) of the metric to determine when to stop. + """ + + def __init__( + self, + *, + patience: int = 10, + start_after: int = 20, + min_moment_magnitude: float = 0.001, + ): + self.patience = patience + self.start_after = start_after + self.min_moment_magnitude = min_moment_magnitude + + self.moment_magnitudes = deque(maxlen=patience) + self.last_score = 0.0 + self.num_calls = 0 + + def update(self, score: float) -> bool: + # caclulate the current moment (dx/dt) + self.moment_magnitudes.append(abs(score - self.last_score)) + self.last_score = score + self.num_calls += 1 + + if ( + self.num_calls < self.start_after + or len(self.moment_magnitudes) < self.patience + ): + return False + print(mean(self.moment_magnitudes)) + + if mean(self.moment_magnitudes) < self.min_moment_magnitude: + return True + + return False + + class Task: shorthand: str validation_dataset: Dataset @@ -61,6 +102,10 @@ class Task: pass def evaluate(self, prompt: str, dataset: Dataset) -> tuple[float, ModelUsage]: + early_stopping = MomentBasedStopping( + patience=len(dataset) // 20, + start_after=len(dataset) // 5, + ) results: list = [] dataset_iterator: tqdm[DatasetDatum] = tqdm( dataset, desc="evaluating prompt", leave=False @@ -75,6 +120,11 @@ class Task: {self.metric_name: f"{current_metric*100:.1f}%"} ) evaluation_usage += usage + if early_stopping.update(current_metric): + logger.info( + f"Early stopping after {len(results)} samples with {self.metric_name} of {current_metric*100:.1f}%" + ) + break return self._aggregate_result(results), evaluation_usage -- GitLab