diff --git a/evoprompt/task/question_answering.py b/evoprompt/task/question_answering.py index 328ab8227fdf38220e55e3d979c72b2eff7e30ee..d909578a8d08c1fe83419c6e2d42f3e41ef26758 100644 --- a/evoprompt/task/question_answering.py +++ b/evoprompt/task/question_answering.py @@ -7,7 +7,8 @@ from datasets import Dataset from evaluate import load as load_metric from llama_cpp import LlamaGrammar -from evoprompt.task.task import SYSTEM_MESSAGE, DatasetDatum, Task +from evoprompt.opt_types import ModelUsage +from evoprompt.task.task import DatasetDatum, Task logger = logging.getLogger(__name__) @@ -109,7 +110,8 @@ class QuestionAnswering(Task): prompt: str, dataset: Dataset, parent_histories: list[list[float]] | None = None, - ): + no_early_stopping: bool = False, + ) -> tuple[float, ModelUsage, list[float]]: def replace_symbol_for_grammar(sample: DatasetDatum): symbol_replacement_mapping = { "\u2013": "-", @@ -133,7 +135,7 @@ class QuestionAnswering(Task): if self.use_grammar: # NOTE: the LlamaGrammar has issues with symbol '–' therefore we replace all occurences with '-' (hyphen) dataset = dataset.map(replace_symbol_for_grammar, desc="Replacing symbols") - return super().evaluate(prompt, dataset, parent_histories=parent_histories) + return super().evaluate(prompt, dataset, parent_histories, no_early_stopping) @property def metric_name(self): diff --git a/main.py b/main.py index 6a3d6b51016bb33b0ea0b2375e0f4b23350f4a12..531185c4f53cf64213e9b18722e7b47568617cec 100644 --- a/main.py +++ b/main.py @@ -50,7 +50,8 @@ if __name__ == "__main__": raise ValueError( f"'{os.getenv('EP_DEBUG')}' is not allowed for env variable EP_DEBUG." ) - logger.info("DEBUG mode: Do a quick run") + if debug: + logger.info("DEBUG mode: Do a quick run") # set up evolution model evolution_model = LLMModel.get_model(options.evolution_engine, options)