From bbce7ee4e972b1e9553a06185495ba79ce6f1b2a Mon Sep 17 00:00:00 2001 From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de> Date: Thu, 22 Aug 2024 15:32:24 +0200 Subject: [PATCH] Consider debug mode for paraphrasing --- evoprompt/evolution.py | 2 +- evoprompt/optimization.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/evoprompt/evolution.py b/evoprompt/evolution.py index 9e1a9f3..96b072f 100644 --- a/evoprompt/evolution.py +++ b/evoprompt/evolution.py @@ -98,7 +98,7 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta): self.population_size = 3 num_iterations = 2 - self.init_run(self.population_size, num_iterations) + self.init_run(self.population_size, num_iterations, debug=debug) # Algorithm 1 Discrete prompt optimization: EVOPROMPT diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py index 4f41949..e6a9947 100644 --- a/evoprompt/optimization.py +++ b/evoprompt/optimization.py @@ -158,9 +158,11 @@ class PromptOptimization: ) return self.task.evaluate_validation(prompt, parent_histories) - def get_initial_prompts(self, num_initial_prompts: int): + def get_initial_prompts(self, num_initial_prompts: int, debug: bool = False): # this implements the para_topk algorothm from https://github.com/beeevita/EvoPrompt base_prompts = self.task.base_prompts + if debug: + base_prompts = base_prompts[:2] # evaluate all base prompts logger.info("Evaluating base prompts.") @@ -239,7 +241,7 @@ class PromptOptimization: def get_prompts(self, prompt_ids: list[str]): return [self.get_prompt(p_id) for p_id in prompt_ids] - def init_run(self, num_initial_prompts: int, num_iterations: int): + def init_run(self, num_initial_prompts: int, num_iterations: int, debug: bool = False): # family_tree contains the relation of prompts to its parents self.family_tree: dict[str, tuple[str, ...] | None] = {} # all_prompts contains a list of Prompt objects that took part in the optimization @@ -255,7 +257,7 @@ class PromptOptimization: self.save_snapshot() # the initial prompts - initial_prompts, prompt_sources = self.get_initial_prompts(num_initial_prompts) + initial_prompts, prompt_sources = self.get_initial_prompts(num_initial_prompts, debug=debug) self.P.append([]) for prompt, prompt_source in zip(initial_prompts, prompt_sources): prompt = self.add_prompt( -- GitLab