Skip to content
Snippets Groups Projects
Commit bbce7ee4 authored by Max Kimmich's avatar Max Kimmich
Browse files

Consider debug mode for paraphrasing

parent 662a11e8
No related merge requests found
...@@ -98,7 +98,7 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta): ...@@ -98,7 +98,7 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
self.population_size = 3 self.population_size = 3
num_iterations = 2 num_iterations = 2
self.init_run(self.population_size, num_iterations) self.init_run(self.population_size, num_iterations, debug=debug)
# Algorithm 1 Discrete prompt optimization: EVOPROMPT # Algorithm 1 Discrete prompt optimization: EVOPROMPT
......
...@@ -158,9 +158,11 @@ class PromptOptimization: ...@@ -158,9 +158,11 @@ class PromptOptimization:
) )
return self.task.evaluate_validation(prompt, parent_histories) return self.task.evaluate_validation(prompt, parent_histories)
def get_initial_prompts(self, num_initial_prompts: int): def get_initial_prompts(self, num_initial_prompts: int, debug: bool = False):
# this implements the para_topk algorothm from https://github.com/beeevita/EvoPrompt # this implements the para_topk algorothm from https://github.com/beeevita/EvoPrompt
base_prompts = self.task.base_prompts base_prompts = self.task.base_prompts
if debug:
base_prompts = base_prompts[:2]
# evaluate all base prompts # evaluate all base prompts
logger.info("Evaluating base prompts.") logger.info("Evaluating base prompts.")
...@@ -239,7 +241,7 @@ class PromptOptimization: ...@@ -239,7 +241,7 @@ class PromptOptimization:
def get_prompts(self, prompt_ids: list[str]): def get_prompts(self, prompt_ids: list[str]):
return [self.get_prompt(p_id) for p_id in prompt_ids] return [self.get_prompt(p_id) for p_id in prompt_ids]
def init_run(self, num_initial_prompts: int, num_iterations: int): def init_run(self, num_initial_prompts: int, num_iterations: int, debug: bool = False):
# family_tree contains the relation of prompts to its parents # family_tree contains the relation of prompts to its parents
self.family_tree: dict[str, tuple[str, ...] | None] = {} self.family_tree: dict[str, tuple[str, ...] | None] = {}
# all_prompts contains a list of Prompt objects that took part in the optimization # all_prompts contains a list of Prompt objects that took part in the optimization
...@@ -255,7 +257,7 @@ class PromptOptimization: ...@@ -255,7 +257,7 @@ class PromptOptimization:
self.save_snapshot() self.save_snapshot()
# the initial prompts # the initial prompts
initial_prompts, prompt_sources = self.get_initial_prompts(num_initial_prompts) initial_prompts, prompt_sources = self.get_initial_prompts(num_initial_prompts, debug=debug)
self.P.append([]) self.P.append([])
for prompt, prompt_source in zip(initial_prompts, prompt_sources): for prompt, prompt_source in zip(initial_prompts, prompt_sources):
prompt = self.add_prompt( prompt = self.add_prompt(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment