Commit eb8df827 authored by Grießhaber Daniel

add option to disable early stopping

parent f73cc390
@@ -68,12 +68,9 @@ class EvolutionAlgorithm(PromptOptimization):
         # development set of the i-th prompt in the population, which contains a total
         # of N prompts. The probability of selecting the i-th prompt as a parent can be expressed as
         # pi = si / Σj=1->N sj.
-        scores = [prompt.score for prompt in prompts]
-        if sum(scores) == 0:
-            # sum of scores is 0 ==> each score is 0, draw with equal probability
-            selection_probabilities = len(scores) * [1 / len(scores)]
-        else:
-            selection_probabilities = [score / sum(scores) for score in scores]
+        # add small value to avoid zero chance of selection for some prompts
+        scores = [prompt.score + 1e-6 for prompt in prompts]
+        selection_probabilities = [score / sum(scores) for score in scores]
         return choice(prompts, size=2, replace=False, p=selection_probabilities)

     @abstractmethod
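
The change above replaces the explicit uniform-fallback branch with a small additive constant: every prompt keeps a non-zero selection probability even when all scores are zero, and the probabilities still normalize to 1. A minimal, self-contained sketch of the resulting behaviour (the Prompt dataclass below is a hypothetical stand-in for the project's prompt objects, not code from this repository):

from dataclasses import dataclass

from numpy.random import choice


@dataclass
class Prompt:
    # hypothetical stand-in for the real prompt type used by EvolutionAlgorithm
    text: str
    score: float


prompts = [Prompt("p1", 0.0), Prompt("p2", 0.0), Prompt("p3", 0.0)]

# same computation as the new code: the 1e-6 offset keeps every probability > 0
scores = [prompt.score + 1e-6 for prompt in prompts]
selection_probabilities = [score / sum(scores) for score in scores]
print(selection_probabilities)  # ~[0.333, 0.333, 0.333] when all scores are 0

# roulette-wheel selection of two distinct parents, as in the diff
parents = choice(prompts, size=2, replace=False, p=selection_probabilities)
print([parent.text for parent in parents])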
@@ -5,13 +5,14 @@ from functools import lru_cache
from statistics import mean
from typing import Union
from cli import argument_parser
from datasets import Dataset, load_dataset
from evaluate import load as load_metric
from llama_cpp import LlamaGrammar, deque
from tqdm import tqdm
from cli import argument_parser
from models import Llama2, LLMModel, OpenAI
from opt_types import ModelUsage
from tqdm import tqdm
from utils import log_calls, logger
SYSTEM_MESSAGE = """
@@ -121,6 +122,7 @@ class Task:
         test_dataset: str,
         *,
         use_grammar: bool,
+        no_early_stopping: bool = False,
         validation_split: str | None = None,
         test_split: str | None = None,
     ) -> None:
@@ -128,6 +130,7 @@ class Task:
         # whether we use the grammar to constrain the model output or not
         self.use_grammar = use_grammar
+        self.no_early_stopping = no_early_stopping
         self.validation_dataset = load_dataset(
             validation_dataset, split=validation_split
@@ -183,7 +186,7 @@
             )
             evaluation_usage += usage
             evaluation_history.append(current_metric)
-            if early_stopping.update(current_metric):
+            if not self.no_early_stopping and early_stopping.update(current_metric):
                 logger.info(
                     f"Early stopping after {len(results)} samples with {self.metric_name} of {current_metric*100:.1f}%"
                 )
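
With the new flag, the early-stopping check is short-circuited: when no_early_stopping is set, early_stopping.update(...) is never called, so evaluation always runs over the full dataset. A hedged sketch of that guard pattern; the EarlyStopping class below is a simplified, hypothetical stand-in (the project's actual implementation is not shown in this diff):

class EarlyStopping:
    # simplified stand-in: stop once the metric has not improved
    # for `patience` consecutive updates
    def __init__(self, patience: int = 3) -> None:
        self.patience = patience
        self.best = float("-inf")
        self.stale = 0

    def update(self, metric: float) -> bool:
        if metric > self.best:
            self.best, self.stale = metric, 0
        else:
            self.stale += 1
        return self.stale >= self.patience


no_early_stopping = True  # value carried by the new constructor argument / CLI flag
early_stopping = EarlyStopping()

for current_metric in [0.62, 0.61, 0.61, 0.61, 0.61]:
    # `and` short-circuits: with no_early_stopping=True the update is skipped entirely
    if not no_early_stopping and early_stopping.update(current_metric):
        print("early stopping")
        break
else:
    print("evaluated all samples")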
@@ -234,6 +237,7 @@ class SentimentAnalysis(Task):
             validation_dataset="SetFit/sst2",
             test_dataset="SetFit/sst2",
             use_grammar=options.use_grammar,
+            no_early_stopping=options.no_early_stopping,
             validation_split=f"validation[:{5 if options.debug else 200}]",
             test_split="test[:20]" if options.debug else "test",
         )
@@ -339,6 +343,7 @@ class QuestionAnswering(Task):
             "squad",
             "squad",
             use_grammar=options.use_grammar,
+            no_early_stopping=options.no_early_stopping,
             validation_split=f"train[:{5 if options.debug else 200}]",
             test_split="validation[:20]" if options.debug else "validation",
         )
@@ -445,6 +450,7 @@ def get_task(name: str, evaluation_model: LLMModel, options: Namespace):
 argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
 argument_group = argument_parser.add_argument_group("Task arguments")
 argument_group.add_argument("--use-grammar", "-g", action="store_true")
+argument_group.add_argument("--no-early-stopping", action="store_true")
 argument_group.add_argument(
-    "--task", "-t", type=str, required=True, choices=["sa", "qa"]
+    "--task", "-t", type=str, required=True, choices=tasks.keys()
 )
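
Because the new flag uses action="store_true", options.no_early_stopping defaults to False and becomes True only when the switch is passed on the command line. A small usage sketch follows; the task key "sa" and the exact parser invocation are assumptions based on this diff rather than documented behaviour:

# hedged usage sketch, assuming `argument_parser` is the parser defined in cli.py
from cli import argument_parser

options = argument_parser.parse_args(["--task", "sa", "--no-early-stopping"])
print(options.no_early_stopping)  # True: Task subclasses receive no_early_stopping=True
print(options.use_grammar)        # False unless -g/--use-grammar is also passed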