From eb8df827c6b4cd1505080cb2f348125a8406c584 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Mon, 27 May 2024 07:30:33 +0200
Subject: [PATCH] add option to disable early stopping

---
 evolution.py |  9 +++------
 task.py      | 14 ++++++++++----
 2 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/evolution.py b/evolution.py
index dc133cc..25b838f 100644
--- a/evolution.py
+++ b/evolution.py
@@ -68,12 +68,9 @@ class EvolutionAlgorithm(PromptOptimization):
         # development set of the i-th prompt in the population, which contains a total
         # of N prompts. The probability of selecting the i-th prompt as a parent can be expressed as
         # pi = si / Σj=1->N sj.
-        scores = [prompt.score for prompt in prompts]
-        if sum(scores) == 0:
-            # sum of scores is 0 ==> each score is 0, draw with equal probability
-            selection_probabilities = len(scores) * [1 / len(scores)]
-        else:
-            selection_probabilities = [score / sum(scores) for score in scores]
+        # add small value to avoid zero chance of selection for some prompts
+        scores = [prompt.score + 1e-6 for prompt in prompts]
+        selection_probabilities = [score / sum(scores) for score in scores]
         return choice(prompts, size=2, replace=False, p=selection_probabilities)
 
     @abstractmethod
diff --git a/task.py b/task.py
index abb26c3..40d8fe1 100644
--- a/task.py
+++ b/task.py
@@ -5,13 +5,14 @@ from functools import lru_cache
 from statistics import mean
 from typing import Union
 
-from cli import argument_parser
 from datasets import Dataset, load_dataset
 from evaluate import load as load_metric
 from llama_cpp import LlamaGrammar, deque
+from tqdm import tqdm
+
+from cli import argument_parser
 from models import Llama2, LLMModel, OpenAI
 from opt_types import ModelUsage
-from tqdm import tqdm
 from utils import log_calls, logger
 
 SYSTEM_MESSAGE = """
@@ -121,6 +122,7 @@ class Task:
         test_dataset: str,
         *,
         use_grammar: bool,
+        no_early_stopping: bool = False,
         validation_split: str | None = None,
         test_split: str | None = None,
     ) -> None:
@@ -128,6 +130,7 @@ class Task:
 
         # whether we use the grammar to constrain the model output or not
         self.use_grammar = use_grammar
+        self.no_early_stopping = no_early_stopping
 
         self.validation_dataset = load_dataset(
             validation_dataset, split=validation_split
@@ -183,7 +186,7 @@ class Task:
                 )
                 evaluation_usage += usage
                 evaluation_history.append(current_metric)
-                if early_stopping.update(current_metric):
+                if not self.no_early_stopping and early_stopping.update(current_metric):
                     logger.info(
                         f"Early stopping after {len(results)} samples with {self.metric_name} of {current_metric*100:.1f}%"
                     )
@@ -234,6 +237,7 @@ class SentimentAnalysis(Task):
             validation_dataset="SetFit/sst2",
             test_dataset="SetFit/sst2",
             use_grammar=options.use_grammar,
+            no_early_stopping=options.no_early_stopping,
             validation_split=f"validation[:{5 if options.debug else 200}]",
             test_split="test[:20]" if options.debug else "test",
         )
@@ -339,6 +343,7 @@ class QuestionAnswering(Task):
             "squad",
             "squad",
             use_grammar=options.use_grammar,
+            no_early_stopping=options.no_early_stopping,
             validation_split=f"train[:{5 if options.debug else 200}]",
             test_split="validation[:20]" if options.debug else "validation",
         )
@@ -445,6 +450,7 @@ def get_task(name: str, evaluation_model: LLMModel, options: Namespace):
 argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
 argument_group = argument_parser.add_argument_group("Task arguments")
 argument_group.add_argument("--use-grammar", "-g", action="store_true")
+argument_group.add_argument("--no-early-stopping", action="store_true")
 argument_group.add_argument(
-    "--task", "-t", type=str, required=True, choices=["sa", "qa"]
+    "--task", "-t", type=str, required=True, choices=tasks.keys()
 )
-- 
GitLab
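
A standalone sketch of the evolution.py hunk, for readers skimming the patch: the flat 1e-6 shift replaces the explicit all-zero branch, since it keeps every selection probability strictly positive and the normalization p_i = s_i / sum_j s_j well-defined even when the whole population scores zero. The Prompt stand-in below is hypothetical, and `choice` is assumed to be numpy.random.choice (consistent with the size/replace/p keywords in the call).

from dataclasses import dataclass

from numpy.random import choice


@dataclass
class Prompt:
    text: str
    score: float  # mean metric on the development set, in [0, 1]


def select_parents(prompts: list[Prompt]) -> list[Prompt]:
    # Shift each score by a small epsilon: an all-zero population now yields a
    # uniform distribution instead of a division by zero, and no prompt ever
    # has exactly zero chance of being picked.
    scores = [prompt.score + 1e-6 for prompt in prompts]
    # Fitness-proportionate selection: p_i = s_i / sum_j s_j
    selection_probabilities = [score / sum(scores) for score in scores]
    # Draw two distinct parents according to these probabilities.
    return list(choice(prompts, size=2, replace=False, p=selection_probabilities))


# An all-zero population degenerates to a uniform draw rather than crashing:
print(select_parents([Prompt("A", 0.0), Prompt("B", 0.0), Prompt("C", 0.0)]))

Note that for populations with non-zero scores the epsilon is negligible, so relative selection pressure is essentially unchanged; only the degenerate case behaves differently.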
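
On the task.py side, the flag is consulted only at the evaluation loop's early-exit check. Because `and` short-circuits, `early_stopping.update(...)` is not even called when the flag is set, so a run with --no-early-stopping always scores the full split. A minimal sketch of that pattern follows; the EarlyStopping helper here is a hypothetical moving-average stand-in, not the repository's actual class, which the patch does not show.

from collections import deque
from statistics import mean
from typing import Callable, Iterable


class EarlyStopping:
    """Hypothetical stand-in: stop once a sliding-window mean drops below a threshold."""

    def __init__(self, window: int = 10, threshold: float = 0.2) -> None:
        self.history: deque[float] = deque(maxlen=window)
        self.threshold = threshold

    def update(self, current_metric: float) -> bool:
        self.history.append(current_metric)
        # Trigger only once the window is full and the running mean is poor.
        return len(self.history) == self.history.maxlen and mean(self.history) < self.threshold


def evaluate(
    samples: Iterable[str],
    score_fn: Callable[[str], float],
    *,
    no_early_stopping: bool = False,
) -> list[float]:
    early_stopping = EarlyStopping()
    results: list[float] = []
    for sample in samples:
        current_metric = score_fn(sample)
        results.append(current_metric)
        # Mirrors the patched condition: the flag short-circuits the check, so
        # the early-stopping state is never updated when stopping is disabled.
        if not no_early_stopping and early_stopping.update(current_metric):
            break
    return results

With the patch applied, the behavior is opted into per run via the new store_true argument (--no-early-stopping, default False), matching the keyword default in Task.__init__.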