Commit b77a62b7 authored by Grießhaber Daniel

added parent based early stopping monitor

parent 28441f6d
@@ -169,7 +169,7 @@ class EvolutionAlgorithm(PromptOptimization):
         logger.info(f"Best prompt: {p}")
         # We pick the prompt with the highest score on the development set and report its score on the testset.
-        test_performance, _ = self.task.evaluate_test(p.content)
+        test_performance, _, _ = self.task.evaluate_test(p.content)
         logger.info("Best prompt on test set: %s", test_performance)
         logger.info(
             "Usage (evolution model / evaluation model / total): %s / %s / %s",
@@ -30,6 +30,7 @@ class Prompt:
     content: str
     score: float
     usage: ModelUsage
+    evaluation_history: list[float]
     meta: dict = field(default_factory=dict)
     id: str = field(default_factory=lambda: uuid4().hex)
@@ -58,22 +58,36 @@ class PromptOptimization:
     def reset(self):
         self._init

-    def evaluate_prompt(self, prompt: str):
-        return self.task.evaluate_validation(prompt)
+    def evaluate_prompt(self, prompt: str, parents: tuple[Prompt] | None = None):
+        parent_histories = (
+            [parent.evaluation_history for parent in parents]
+            if parents is not None
+            else None
+        )
+        return self.task.evaluate_validation(prompt, parent_histories)

     def add_prompt(
-        self, prompt: str, parents: tuple[Prompt] = None, meta: dict = None
+        self,
+        prompt: str,
+        parents: tuple[Prompt] | None = None,
+        meta: dict | None = None,
     ) -> Prompt:
-        score, usage = self.evaluate_prompt(prompt)
-        prompt = Prompt(content=prompt, score=score, meta=meta, usage=usage)
+        score, usage, history = self.evaluate_prompt(prompt, parents)
+        prompt_object = Prompt(
+            content=prompt,
+            score=score,
+            meta=meta if meta is not None else {},
+            usage=usage,
+            evaluation_history=history,
+        )

         # keep track of prompt
-        self.all_prompts[prompt.id] = prompt
-        self.family_tree[prompt.id] = (
+        self.all_prompts[prompt_object.id] = prompt_object
+        self.family_tree[prompt_object.id] = (
             tuple(p.id for p in parents) if parents is not None else None
         )

-        return prompt
+        return prompt_object

     def add_prompts(
         self,
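Note: a minimal sketch of how the new plumbing fits together, assuming Prompt and ModelUsage are importable from the project's modules and using made-up scores purely for illustration:

    # two parent prompts that already carry their per-sample evaluation history
    parent_a = Prompt(content="prompt A", score=0.61, usage=ModelUsage(),
                      evaluation_history=[0.50, 0.58, 0.61])
    parent_b = Prompt(content="prompt B", score=0.64, usage=ModelUsage(),
                      evaluation_history=[0.60, 0.63, 0.64])

    # evaluate_prompt() collects the parents' metric curves ...
    parent_histories = [p.evaluation_history for p in (parent_a, parent_b)]
    # ... and evaluate_validation() forwards them to Task.evaluate(), where they
    # seed a ParentBaselineBasedStopping monitor for the child prompt.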
@@ -23,7 +23,14 @@ You are given an instruction that describes a task, paired with an input that pr
 DatasetDatum = dict


-class MomentBasedStopping:
+class EarlyStoppingMonitor:
+
+    @abstractmethod
+    def update(self, score: float) -> bool:
+        raise NotImplementedError
+
+
+class MomentBasedStopping(EarlyStoppingMonitor):
     """
     Watch the first derivative (moment) of the metric to determine when to stop.
     """
@@ -45,18 +52,60 @@ class MomentBasedStopping:
     def update(self, score: float) -> bool:
         # calculate the current moment (dx/dt)
         self.num_calls += 1
+        if self.num_calls < self.start_after:
+            return False
         self.moment_magnitudes.append(abs(score - self.last_score))
         self.last_score = score
-        if (
-            self.num_calls < self.start_after
-            or len(self.moment_magnitudes) < self.patience
-        ):
+        if len(self.moment_magnitudes) < self.patience:
             return False
-        print(mean(self.moment_magnitudes))
         if mean(self.moment_magnitudes) < self.min_moment_magnitude:
             return True
         return False
+
+
+class ParentBaselineBasedStopping(EarlyStoppingMonitor):
+
+    def __init__(
+        self,
+        parent_histories: list[list[float]],
+        *,
+        patience: int = 10,
+        start_after: int = 20,
+        min_improvement: float = 0.001,
+    ):
+        self.parent_histories = parent_histories
+        self.patience = patience
+        self.start_after = start_after
+        self.min_improvement = min_improvement
+        self.num_calls = 0
+        self.improvement_memory = deque(maxlen=patience)
+
+    def update(self, score: float) -> bool:
+        self.num_calls += 1
+        if self.num_calls < self.start_after:
+            return False
+
+        parent_values = [  # get the metric value of the parents at the current step
+            (
+                parent_history[self.num_calls - 1]
+                if len(parent_history) >= self.num_calls
+                else parent_history[-1]  # extend with last value
+            )
+            for parent_history in self.parent_histories
+        ]
+        self.improvement_memory.append(
+            score - max(parent_values)  # compare with the best parent
+        )
+
+        if len(self.improvement_memory) < self.patience:
+            return False
+
+        if max(self.improvement_memory) < self.min_improvement:
+            # if the highest improvement is less than the minimum improvement, we stop
+            return True
+        return False
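Note: a small usage sketch of the new monitor, assuming the classes above are importable from this module. ParentBaselineBasedStopping compares each incoming score against the best parent's score at the same evaluation step (histories shorter than the current run are extended with their last value) and signals a stop once none of the last patience steps improved on the parents by at least min_improvement:

    monitor = ParentBaselineBasedStopping(
        parent_histories=[[0.50, 0.55, 0.60, 0.62], [0.48, 0.52, 0.58, 0.59]],
        patience=2,
        start_after=2,
        min_improvement=0.01,
    )
    for step, score in enumerate([0.51, 0.53, 0.58, 0.60], start=1):
        if monitor.update(score):
            # the child never beat the best parent by at least 0.01 in the last 2 steps
            print(f"stopping evaluation early after step {step}")
            break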
@@ -101,16 +150,31 @@ class Task:
     def _aggregate_result(self, results: list) -> float:
         pass

-    def evaluate(self, prompt: str, dataset: Dataset) -> tuple[float, ModelUsage]:
-        early_stopping = MomentBasedStopping(
-            patience=len(dataset) // 20,
-            start_after=len(dataset) // 5,
-        )
+    def evaluate(
+        self,
+        prompt: str,
+        dataset: Dataset,
+        parent_histories: list[list[float]] | None = None,
+    ) -> tuple[float, ModelUsage, list[float]]:
+        early_stopping: EarlyStoppingMonitor
+        early_stopping_params = {
+            "patience": max(len(dataset) // 20, 5),
+            "start_after": max(len(dataset) // 5, 5),
+        }
+        if parent_histories is not None:
+            early_stopping = ParentBaselineBasedStopping(
+                parent_histories, **early_stopping_params
+            )
+        else:
+            early_stopping = MomentBasedStopping(**early_stopping_params)

         results: list = []
         dataset_iterator: tqdm[DatasetDatum] = tqdm(
             dataset, desc="evaluating prompt", leave=False
         )
         evaluation_usage = ModelUsage()
+        evaluation_history = []

         for datum in dataset_iterator:
             result, usage = self._evaluate_sample(prompt, datum)
@@ -120,18 +184,20 @@ class Task:
                 {self.metric_name: f"{current_metric*100:.1f}%"}
             )
             evaluation_usage += usage
+            evaluation_history.append(current_metric)
             if early_stopping.update(current_metric):
                 logger.info(
                     f"Early stopping after {len(results)} samples with {self.metric_name} of {current_metric*100:.1f}%"
                 )
                 break

-        return self._aggregate_result(results), evaluation_usage
+        return self._aggregate_result(results), evaluation_usage, evaluation_history

     @log_calls("Evaluating validation dataset")
     @lru_cache(maxsize=None)
-    def evaluate_validation(self, prompt: str):
-        return self.evaluate(prompt, self.validation_dataset)
+    def evaluate_validation(
+        self, prompt: str, parent_histories: list[list[float]] | None = None
+    ):
+        return self.evaluate(prompt, self.validation_dataset, parent_histories)

     @log_calls("Evaluating test dataset")
     def evaluate_test(self, prompt: str):
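Note: Task.evaluate (and with it evaluate_validation and evaluate_test) now returns a triple rather than a pair; a hedged sketch of the new call shape, with the task instance assumed to be constructed elsewhere:

    score, usage, history = task.evaluate_validation(prompt)
    # history is the running metric after each evaluated sample; add_prompt() stores
    # it on the Prompt so it can later serve as a parent baseline for offspring.

One caveat: lru_cache only accepts hashable arguments, so calling evaluate_validation with parent_histories as a list of lists will raise a TypeError at call time; passing tuples (or keeping the histories out of the cache key) would avoid this.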
@@ -207,7 +273,7 @@ class SentimentAnalysis(Task):
                 break
         else:
             logger.warning(f"Invalid answer: {response}")
-            return "failed"
+            return "failed", usage

         classification_result = (
             "incorrect" if answer_label != datum["label"] else "correct"
@@ -333,7 +399,7 @@ class QuestionAnswering(Task):
     def _aggregate_result(self, results: list[float]) -> float:
         return sum(results) / len(results)

-    def evaluate(self, prompt: str, dataset: Dataset) -> tuple[float, ModelUsage]:
+    def evaluate(self, prompt: str, dataset: Dataset):
         def replace_symbol_for_grammar(sample: DatasetDatum):
             symbol_replacement_mapping = {
                 "\u2013": "-",