From b77a62b7b215ef1cc641b0730e4cc93b414f16c6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Thu, 25 Apr 2024 13:28:40 +0200
Subject: [PATCH] added parent based early stopping monitor

---
 evolution.py    |   2 +-
 opt_types.py    |   1 +
 optimization.py |  30 ++++++++++----
 task.py         | 102 +++++++++++++++++++++++++++++++++++++++---------
 4 files changed, 108 insertions(+), 27 deletions(-)

diff --git a/evolution.py b/evolution.py
index 4bf51eb..d2c3e1f 100644
--- a/evolution.py
+++ b/evolution.py
@@ -169,7 +169,7 @@ class EvolutionAlgorithm(PromptOptimization):
             logger.info(f"Best prompt: {p}")
 
         # We pick the prompt with the highest score on the development set and report its score on the testset.
-        test_performance, _ = self.task.evaluate_test(p.content)
+        test_performance, _, _ = self.task.evaluate_test(p.content)
         logger.info("Best prompt on test set: %s", test_performance)
         logger.info(
             "Usage (evolution model / evaluation model / total): %s / %s / %s",
diff --git a/opt_types.py b/opt_types.py
index 3d46de8..a80c850 100644
--- a/opt_types.py
+++ b/opt_types.py
@@ -30,6 +30,7 @@ class Prompt:
     content: str
     score: float
     usage: ModelUsage
+    evaluation_history: list[float]
     meta: dict = field(default_factory=dict)
     id: str = field(default_factory=lambda: uuid4().hex)
 
diff --git a/optimization.py b/optimization.py
index 90d6e2e..7819125 100644
--- a/optimization.py
+++ b/optimization.py
@@ -58,22 +58,36 @@ class PromptOptimization:
     def reset(self):
         self._init
 
-    def evaluate_prompt(self, prompt: str):
-        return self.task.evaluate_validation(prompt)
+    def evaluate_prompt(self, prompt: str, parents: tuple[Prompt] | None = None):
+        parent_histories = (
+            [parent.evaluation_history for parent in parents]
+            if parents is not None
+            else None
+        )
+        return self.task.evaluate_validation(prompt, parent_histories)
 
     def add_prompt(
-        self, prompt: str, parents: tuple[Prompt] = None, meta: dict = None
+        self,
+        prompt: str,
+        parents: tuple[Prompt] | None = None,
+        meta: dict | None = None,
     ) -> Prompt:
-        score, usage = self.evaluate_prompt(prompt)
-        prompt = Prompt(content=prompt, score=score, meta=meta, usage=usage)
+        score, usage, history = self.evaluate_prompt(prompt, parents)
+        prompt_object = Prompt(
+            content=prompt,
+            score=score,
+            meta=meta if meta is not None else {},
+            usage=usage,
+            evaluation_history=history,
+        )
 
         # keep track of prompt
-        self.all_prompts[prompt.id] = prompt
-        self.family_tree[prompt.id] = (
+        self.all_prompts[prompt_object.id] = prompt_object
+        self.family_tree[prompt_object.id] = (
             tuple(p.id for p in parents) if parents is not None else None
         )
 
-        return prompt
+        return prompt_object
 
     def add_prompts(
         self,
diff --git a/task.py b/task.py
index e474b03..d062253 100644
--- a/task.py
+++ b/task.py
@@ -23,7 +23,14 @@ You are given an instruction that describes a task, paired with an input that pr
 DatasetDatum = dict
 
 
-class MomentBasedStopping:
+class EarlyStoppingMonitor:
+
+    @abstractmethod
+    def update(self, score: float) -> bool:
+        raise NotImplementedError
+
+
+class MomentBasedStopping(EarlyStoppingMonitor):
     """
     Watch the first derivative (moment) of the metric to determine when to stop.
     """
@@ -45,18 +52,60 @@ class MomentBasedStopping:
 
     def update(self, score: float) -> bool:
         # caclulate the current moment (dx/dt)
+        self.num_calls += 1
+        if self.num_calls < self.start_after:
+            return False
+
         self.moment_magnitudes.append(abs(score - self.last_score))
         self.last_score = score
+        if len(self.moment_magnitudes) < self.patience:
+            return False
+
+        if mean(self.moment_magnitudes) < self.min_moment_magnitude:
+            return True
+
+        return False
+
+
+class ParentBaselineBasedStopping(EarlyStoppingMonitor):
+
+    def __init__(
+        self,
+        parent_histories: list[list[float]],
+        *,
+        patience: int = 10,
+        start_after: int = 20,
+        min_improvement: float = 0.001,
+    ):
+        self.parent_histories = parent_histories
+        self.patience = patience
+        self.start_after = start_after
+        self.min_improvement = min_improvement
+        self.num_calls = 0
+        self.improvement_memory = deque(maxlen=patience)
+
+    def update(self, score: float) -> bool:
         self.num_calls += 1
+        if self.num_calls < self.start_after:
+            return False
 
-        if (
-            self.num_calls < self.start_after
-            or len(self.moment_magnitudes) < self.patience
-        ):
+        parent_values = [  # get the metric value of the parents at the current step
+            (
+                parent_history[self.num_calls - 1]
+                if len(parent_history) >= self.num_calls
+                else parent_history[-1]  # extend with last value
+            )
+            for parent_history in self.parent_histories
+        ]
+        self.improvement_memory.append(
+            score - max(parent_values)  # compare with the best parent
+        )
+
+        if len(self.improvement_memory) < self.patience:
             return False
 
-        print(mean(self.moment_magnitudes))
-        if mean(self.moment_magnitudes) < self.min_moment_magnitude:
+        if max(self.improvement_memory) < self.min_improvement:
+            # if the highest improvement is less than the minimum improvement, we stop
            return True
 
         return False
@@ -101,16 +150,31 @@ class Task:
     def _aggregate_result(self, results: list) -> float:
         pass
 
-    def evaluate(self, prompt: str, dataset: Dataset) -> tuple[float, ModelUsage]:
-        early_stopping = MomentBasedStopping(
-            patience=len(dataset) // 20,
-            start_after=len(dataset) // 5,
-        )
+    def evaluate(
+        self,
+        prompt: str,
+        dataset: Dataset,
+        parent_histories: list[list[float]] | None = None,
+    ) -> tuple[float, ModelUsage, list[float]]:
+
+        early_stopping: EarlyStoppingMonitor
+        early_stopping_params = {
+            "patience": max(len(dataset) // 20, 5),
+            "start_after": max(len(dataset) // 5, 5),
+        }
+        if parent_histories is not None:
+            early_stopping = ParentBaselineBasedStopping(
+                parent_histories, **early_stopping_params
+            )
+        else:
+            early_stopping = MomentBasedStopping(**early_stopping_params)
+
         results: list = []
         dataset_iterator: tqdm[DatasetDatum] = tqdm(
             dataset, desc="evaluating prompt", leave=False
         )
         evaluation_usage = ModelUsage()
+        evaluation_history = []
 
         for datum in dataset_iterator:
             result, usage = self._evaluate_sample(prompt, datum)
@@ -120,18 +184,20 @@ class Task:
                 {self.metric_name: f"{current_metric*100:.1f}%"}
             )
             evaluation_usage += usage
+            evaluation_history.append(current_metric)
             if early_stopping.update(current_metric):
                 logger.info(
                     f"Early stopping after {len(results)} samples with {self.metric_name} of {current_metric*100:.1f}%"
                 )
                 break
 
-        return self._aggregate_result(results), evaluation_usage
+        return self._aggregate_result(results), evaluation_usage, evaluation_history
 
     @log_calls("Evaluating validation dataset")
-    @lru_cache(maxsize=None)
-    def evaluate_validation(self, prompt: str):
-        return self.evaluate(prompt, self.validation_dataset)
+    def evaluate_validation(
+        self, prompt: str, parent_histories: list[list[float]] | None = None
+    ):
+        return self.evaluate(prompt, self.validation_dataset, parent_histories)
 
     @log_calls("Evaluating test dataset")
     def evaluate_test(self, prompt: str):
@@ -207,7 +273,7 @@ class SentimentAnalysis(Task):
                     break
             else:
                 logger.warning(f"Invalid answer: {response}")
-                return "failed"
+                return "failed", usage
 
         classification_result = (
             "incorrect" if answer_label != datum["label"] else "correct"
@@ -333,7 +399,7 @@ class QuestionAnswering(Task):
     def _aggregate_result(self, results: list[float]) -> float:
         return sum(results) / len(results)
 
-    def evaluate(self, prompt: str, dataset: Dataset) -> tuple[float, ModelUsage]:
+    def evaluate(self, prompt: str, dataset: Dataset):
         def replace_symbol_for_grammar(sample: DatasetDatum):
             symbol_replacement_mapping = {
                 "\u2013": "-",
-- 
GitLab
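
Usage sketch (not part of the patch): a minimal illustration of how the new ParentBaselineBasedStopping monitor from task.py is expected to behave when a child prompt's running metric is fed to it sample by sample, the same way Task.evaluate does. It assumes the repository root is on PYTHONPATH so task.py imports as `task` (otherwise the class can be copied out, since it only depends on collections.deque); the parent histories and score values below are made up for illustration.

from task import ParentBaselineBasedStopping

# Two parents whose running validation metric was recorded during their own evaluation.
parent_histories = [
    [0.40 + 0.01 * i for i in range(20)],  # parent A: climbs to ~0.59
    [0.45 + 0.01 * i for i in range(18)],  # parent B: shorter run, ends at ~0.62
]

monitor = ParentBaselineBasedStopping(
    parent_histories,
    patience=5,         # stop once 5 consecutive steps show no real improvement
    start_after=10,     # never stop within the first 10 samples
    min_improvement=0.001,
)

# Feed the child's running metric sample by sample, as Task.evaluate does.
child_scores = [0.50] * 30  # a child prompt that never beats its parents
for step, score in enumerate(child_scores, start=1):
    if monitor.update(score):
        print(f"early stop after {step} samples")
        break

With these toy values the monitor triggers at step 14: after the start_after grace period, every comparison against the best parent's metric at the same step falls short of min_improvement, and once patience such comparisons have accumulated the evaluation is cut off. This is the intended saving in the evolutionary loop, where offspring that do not improve over their parents stop consuming evaluation budget early.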