From 7a8cb23bf9d66fa75c384c2de35cf61c3e1155dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de> Date: Sun, 27 Oct 2024 08:24:54 +0100 Subject: [PATCH] use simpler way to measure prompt length for shortest-first strategy --- evoprompt/task/task.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py index 3a6794b..c77e6c6 100644 --- a/evoprompt/task/task.py +++ b/evoprompt/task/task.py @@ -193,7 +193,8 @@ class ShortestFirstStrategy(EarlyStoppingStrategy): self, dataset: Dataset, parent_histories: ParentHistories | None ): sorted_dataset = sorted( - dataset, key=lambda x: len(self.task.build_prompt_input(x)) + dataset, + key=lambda x: len(self.task._get_prompt_text_for_datum(x)), ) return super().get_dataset_iterator(sorted_dataset, parent_histories) @@ -411,7 +412,9 @@ class Task(metaclass=ABCMeta): system_message, prompt_for_datum = self.build_prompt_input( datum, instruction, use_prediction_prefix=self.force_task_prediction_prefix ) - logger.debug(f"==== Task Prediction ====\nSystem message:\n\t{system_message}\nPrompt for datum:\n\t{prompt_for_datum}\nHistory:\n\t{history}") + logger.debug( + f"==== Task Prediction ====\nSystem message:\n\t{system_message}\nPrompt for datum:\n\t{prompt_for_datum}\nHistory:\n\t{history}" + ) response, _, _, usage = self.model.create_completion( system_message=system_message, messages=prompt_for_datum, -- GitLab