From 29cb6b170940284eef087d5feaaf259c03ab4328 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Fri, 9 Aug 2024 06:38:30 +0200
Subject: [PATCH] refactor initial prompt generation

---
 evoprompt/optimization.py | 30 ++++++++++++++++--------------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index 0eecb0c..7b14eb7 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -94,6 +94,21 @@ class PromptOptimization:
         )
         return self.task.evaluate_validation(prompt, parent_histories)
 
+    def get_initial_prompts(self, num_initial_prompts: int):
+        paraphrases, paraphrase_usage = paraphrase_prompts(
+            self.evolution_model,
+            self.task.base_prompt,
+            n=num_initial_prompts - 1,
+            unique_paraphrases=True,
+        )
+        self.total_evolution_usage += paraphrase_usage
+        logger.info(
+            "Paraphrased prompt '%s': %s.",
+            self.task.base_prompt.replace("\r", "\\r").replace("\n", "\\n"),
+            paraphrases,
+        )
+        return [self.task.base_prompt] + paraphrases
+
     def add_prompt(
         self,
         prompt: str,
@@ -159,21 +174,8 @@ class PromptOptimization:
         self.run_directory = initialize_run_directory(self.evolution_model)
         self.save_snapshot()
 
-        paraphrases, paraphrase_usage = paraphrase_prompts(
-            self.evolution_model,
-            self.task.base_prompt,
-            n=num_initial_prompts - 1,
-            unique_paraphrases=True,
-        )
-        self.total_evolution_usage += paraphrase_usage
-        logger.info(
-            "Paraphrased prompt '%s': %s.",
-            self.task.base_prompt.replace("\r", "\\r").replace("\n", "\\n"),
-            paraphrases,
-        )
-
         # the initial prompts
-        initial_prompts = [self.task.base_prompt] + paraphrases
+        initial_prompts = self.get_initial_prompts(num_initial_prompts)
         initial_prompts = self.add_prompts(
             initial_prompts, metas=[{"gen": 0} for _ in initial_prompts]
         )
-- 
GitLab