From c203cb792882d03c6b49e962d69f0154835687ec Mon Sep 17 00:00:00 2001 From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de> Date: Mon, 2 Dec 2024 10:38:56 +0100 Subject: [PATCH] Consider limit for annotations per evolution --- evoprompt/evolution/evolution.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/evoprompt/evolution/evolution.py b/evoprompt/evolution/evolution.py index 3cb816e..d2e9bdd 100644 --- a/evoprompt/evolution/evolution.py +++ b/evoprompt/evolution/evolution.py @@ -43,7 +43,13 @@ SYSTEM_MESSAGE = "Please carefully follow the instruction step-by-step." class EvolutionAnnotationHandler: - def __init__(self, *, strategy, max_annotations=5, max_annotations_per_evolution=1): + def __init__( + self, + *, + strategy: str, + max_annotations: int = 5, + max_annotations_per_evolution: int = 1, + ): # TODO: we could also distribute the max_annotations over the different evolution steps self.annotation_handlers: dict[int, AnnotationHandler] = defaultdict( partial( @@ -65,10 +71,14 @@ class EvolutionAnnotationHandler: if self.strategy is None: # disable annotation return False - if len(self) < self.max_annotations: - if self.strategy == "simple": - # for the simple strategy, we do annotations until we reach the limit - return True + if self.strategy == "simple": + # for the simple strategy, we do annotations until we reach the limit, and optionally limit the number of annotations per evolution + if len(self) < self.max_annotations: + if self.max_annotations_per_evolution is None or ( + len(self.get_annotations_for_evolution(generation, update_step)) + < self.max_annotations_per_evolution + ): + return True return False def add_annotation( @@ -92,6 +102,13 @@ class EvolutionAnnotationHandler: return self.annotation_handlers[evolution_step].get_annotations() raise NotImplementedError(f"Strategy '{self.strategy}' is not implemented.") + def get_annotations_for_evolution(self, generation: int, update_step: int): + return [ + sample + for handler in self.annotation_handlers.values() + for sample in handler.annotated_samples[(generation, update_step)] + ] + class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta): shorthand: str -- GitLab