From eb8df827c6b4cd1505080cb2f348125a8406c584 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Mon, 27 May 2024 07:30:33 +0200
Subject: [PATCH] add option to disable early stopping

Add a --no-early-stopping flag and pass it through to the tasks so that the
early-stopping check during evaluation can be skipped. Also smooth the
parent-selection probabilities with a small epsilon so that prompts with a
score of zero keep a non-zero chance of being selected, reorder the imports
in task.py, and take the --task choices from tasks.keys() instead of a
hard-coded list.
---
 evolution.py |  9 +++------
 task.py      | 14 ++++++++++----
 2 files changed, 13 insertions(+), 10 deletions(-)
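
Notes (not part of the commit message):

The evolution.py change replaces the all-zero special case with epsilon
smoothing of the selection weights. Below is a minimal standalone sketch of
the resulting parent selection, assuming `choice` is numpy.random.choice (as
the size/replace/p keywords in evolution.py suggest) and using an illustrative
stand-in for the real prompt objects:

    from dataclasses import dataclass

    from numpy.random import choice

    @dataclass
    class Prompt:  # illustrative stand-in for the scored prompt objects
        text: str
        score: float

    def select_parents(prompts, epsilon=1e-6):
        # shift every score by a small epsilon so that a prompt with score 0
        # still has a small, non-zero probability of being selected
        scores = [prompt.score + epsilon for prompt in prompts]
        total = sum(scores)
        # the normalised weights sum to 1, as numpy.random.choice requires
        return choice(prompts, size=2, replace=False,
                      p=[score / total for score in scores])

    parents = select_parents([Prompt("a", 0.0), Prompt("b", 0.4), Prompt("c", 0.6)])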

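With the new flag, early stopping during evaluation can be disabled from the
command line; the entry-point script name below is an assumption for
illustration only:

    python main.py --task sa --no-early-stopping
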
diff --git a/evolution.py b/evolution.py
index dc133cc..25b838f 100644
--- a/evolution.py
+++ b/evolution.py
@@ -68,12 +68,9 @@ class EvolutionAlgorithm(PromptOptimization):
         # development set of the i-th prompt in the population, which contains a total
         # of N prompts. The probability of selecting the i-th prompt as a parent can be expressed as
         # pi = si / Σj=1->N sj.
-        scores = [prompt.score for prompt in prompts]
-        if sum(scores) == 0:
-            # sum of scores is 0 ==> each score is 0, draw with equal probability
-            selection_probabilities = len(scores) * [1 / len(scores)]
-        else:
-            selection_probabilities = [score / sum(scores) for score in scores]
+        # add a small epsilon so that no prompt has a zero chance of being selected
+        scores = [prompt.score + 1e-6 for prompt in prompts]
+        selection_probabilities = [score / sum(scores) for score in scores]
         return choice(prompts, size=2, replace=False, p=selection_probabilities)
 
     @abstractmethod
diff --git a/task.py b/task.py
index abb26c3..40d8fe1 100644
--- a/task.py
+++ b/task.py
@@ -5,13 +5,14 @@ from functools import lru_cache
 from statistics import mean
 from typing import Union
 
-from cli import argument_parser
 from datasets import Dataset, load_dataset
 from evaluate import load as load_metric
 from llama_cpp import LlamaGrammar, deque
+from tqdm import tqdm
+
+from cli import argument_parser
 from models import Llama2, LLMModel, OpenAI
 from opt_types import ModelUsage
-from tqdm import tqdm
 from utils import log_calls, logger
 
 SYSTEM_MESSAGE = """
@@ -121,6 +122,7 @@ class Task:
         test_dataset: str,
         *,
         use_grammar: bool,
+        no_early_stopping: bool = False,
         validation_split: str | None = None,
         test_split: str | None = None,
     ) -> None:
@@ -128,6 +130,7 @@ class Task:
 
         # whether we use the grammar to constrain the model output or not
         self.use_grammar = use_grammar
+        self.no_early_stopping = no_early_stopping
 
         self.validation_dataset = load_dataset(
             validation_dataset, split=validation_split
@@ -183,7 +186,7 @@ class Task:
             )
             evaluation_usage += usage
             evaluation_history.append(current_metric)
-            if early_stopping.update(current_metric):
+            if not self.no_early_stopping and early_stopping.update(current_metric):
                 logger.info(
                     f"Early stopping after {len(results)} samples with {self.metric_name} of {current_metric*100:.1f}%"
                 )
@@ -234,6 +237,7 @@ class SentimentAnalysis(Task):
             validation_dataset="SetFit/sst2",
             test_dataset="SetFit/sst2",
             use_grammar=options.use_grammar,
+            no_early_stopping=options.no_early_stopping,
             validation_split=f"validation[:{5 if options.debug else 200}]",
             test_split="test[:20]" if options.debug else "test",
         )
@@ -339,6 +343,7 @@ class QuestionAnswering(Task):
             "squad",
             "squad",
             use_grammar=options.use_grammar,
+            no_early_stopping=options.no_early_stopping,
             validation_split=f"train[:{5 if options.debug else 200}]",
             test_split="validation[:20]" if options.debug else "validation",
         )
@@ -445,6 +450,7 @@ def get_task(name: str, evaluation_model: LLMModel, options: Namespace):
 argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
 argument_group = argument_parser.add_argument_group("Task arguments")
 argument_group.add_argument("--use-grammar", "-g", action="store_true")
+argument_group.add_argument("--no-early-stopping", action="store_true")
 argument_group.add_argument(
-    "--task", "-t", type=str, required=True, choices=["sa", "qa"]
+    "--task", "-t", type=str, required=True, choices=tasks.keys()
 )
-- 
GitLab