From 7a8cb23bf9d66fa75c384c2de35cf61c3e1155dd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Sun, 27 Oct 2024 08:24:54 +0100
Subject: [PATCH] use simpler way to measure prompt length for shortest-first
 strategy

---
 evoprompt/task/task.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index 3a6794b..c77e6c6 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -193,7 +193,8 @@ class ShortestFirstStrategy(EarlyStoppingStrategy):
         self, dataset: Dataset, parent_histories: ParentHistories | None
     ):
         sorted_dataset = sorted(
-            dataset, key=lambda x: len(self.task.build_prompt_input(x))
+            dataset,
+            key=lambda x: len(self.task._get_prompt_text_for_datum(x)),
         )
         return super().get_dataset_iterator(sorted_dataset, parent_histories)
 
@@ -411,7 +412,9 @@ class Task(metaclass=ABCMeta):
         system_message, prompt_for_datum = self.build_prompt_input(
             datum, instruction, use_prediction_prefix=self.force_task_prediction_prefix
         )
-        logger.debug(f"==== Task Prediction ====\nSystem message:\n\t{system_message}\nPrompt for datum:\n\t{prompt_for_datum}\nHistory:\n\t{history}")
+        logger.debug(
+            f"==== Task Prediction ====\nSystem message:\n\t{system_message}\nPrompt for datum:\n\t{prompt_for_datum}\nHistory:\n\t{history}"
+        )
         response, _, _, usage = self.model.create_completion(
             system_message=system_message,
             messages=prompt_for_datum,
-- 
GitLab