From bbce7ee4e972b1e9553a06185495ba79ce6f1b2a Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Thu, 22 Aug 2024 15:32:24 +0200
Subject: [PATCH] Consider debug mode for paraphrasing

---
 evoprompt/evolution.py    | 2 +-
 evoprompt/optimization.py | 8 +++++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/evoprompt/evolution.py b/evoprompt/evolution.py
index 9e1a9f3..96b072f 100644
--- a/evoprompt/evolution.py
+++ b/evoprompt/evolution.py
@@ -98,7 +98,7 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
             self.population_size = 3
             num_iterations = 2
 
-        self.init_run(self.population_size, num_iterations)
+        self.init_run(self.population_size, num_iterations, debug=debug)
 
         # Algorithm 1 Discrete prompt optimization: EVOPROMPT
 
diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index 4f41949..e6a9947 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -158,9 +158,11 @@ class PromptOptimization:
         )
         return self.task.evaluate_validation(prompt, parent_histories)
 
-    def get_initial_prompts(self, num_initial_prompts: int):
+    def get_initial_prompts(self, num_initial_prompts: int, debug: bool = False):
         # this implements the para_topk algorothm from https://github.com/beeevita/EvoPrompt
         base_prompts = self.task.base_prompts
+        if debug:
+            base_prompts = base_prompts[:2]
 
         # evaluate all base prompts
         logger.info("Evaluating base prompts.")
@@ -239,7 +241,7 @@ class PromptOptimization:
     def get_prompts(self, prompt_ids: list[str]):
         return [self.get_prompt(p_id) for p_id in prompt_ids]
 
-    def init_run(self, num_initial_prompts: int, num_iterations: int):
+    def init_run(self, num_initial_prompts: int, num_iterations: int, debug: bool = False):
         # family_tree contains the relation of prompts to its parents
         self.family_tree: dict[str, tuple[str, ...] | None] = {}
         # all_prompts contains a list of Prompt objects that took part in the optimization
@@ -255,7 +257,7 @@ class PromptOptimization:
         self.save_snapshot()
 
         # the initial prompts
-        initial_prompts, prompt_sources = self.get_initial_prompts(num_initial_prompts)
+        initial_prompts, prompt_sources = self.get_initial_prompts(num_initial_prompts, debug=debug)
         self.P.append([])
         for prompt, prompt_source in zip(initial_prompts, prompt_sources):
             prompt = self.add_prompt(
-- 
GitLab