Skip to content
Snippets Groups Projects
Commit 28441f6d authored by Grießhaber Daniel's avatar Grießhaber Daniel :squid:
Browse files

implement moment-based early stopping

parent d47234e3
No related branches found
No related tags found
No related merge requests found
......@@ -2,11 +2,13 @@ import re
from abc import abstractmethod
from argparse import Namespace
from functools import lru_cache
from statistics import mean
from typing import Union
from datasets import Dataset, load_dataset
from evaluate import load as load_metric
from llama_cpp import LlamaGrammar
from llama_cpp import LlamaGrammar, deque
from torch.utils import data
from tqdm import tqdm
from cli import argument_parser
......@@ -21,6 +23,45 @@ You are given an instruction that describes a task, paired with an input that pr
DatasetDatum = dict
class MomentBasedStopping:
"""
Watch the first derivative (moment) of the metric to determine when to stop.
"""
def __init__(
self,
*,
patience: int = 10,
start_after: int = 20,
min_moment_magnitude: float = 0.001,
):
self.patience = patience
self.start_after = start_after
self.min_moment_magnitude = min_moment_magnitude
self.moment_magnitudes = deque(maxlen=patience)
self.last_score = 0.0
self.num_calls = 0
def update(self, score: float) -> bool:
# caclulate the current moment (dx/dt)
self.moment_magnitudes.append(abs(score - self.last_score))
self.last_score = score
self.num_calls += 1
if (
self.num_calls < self.start_after
or len(self.moment_magnitudes) < self.patience
):
return False
print(mean(self.moment_magnitudes))
if mean(self.moment_magnitudes) < self.min_moment_magnitude:
return True
return False
class Task:
shorthand: str
validation_dataset: Dataset
......@@ -61,6 +102,10 @@ class Task:
pass
def evaluate(self, prompt: str, dataset: Dataset) -> tuple[float, ModelUsage]:
early_stopping = MomentBasedStopping(
patience=len(dataset) // 20,
start_after=len(dataset) // 5,
)
results: list = []
dataset_iterator: tqdm[DatasetDatum] = tqdm(
dataset, desc="evaluating prompt", leave=False
......@@ -75,6 +120,11 @@ class Task:
{self.metric_name: f"{current_metric*100:.1f}%"}
)
evaluation_usage += usage
if early_stopping.update(current_metric):
logger.info(
f"Early stopping after {len(results)} samples with {self.metric_name} of {current_metric*100:.1f}%"
)
break
return self._aggregate_result(results), evaluation_usage
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment