From 29cb6b170940284eef087d5feaaf259c03ab4328 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de> Date: Fri, 9 Aug 2024 06:38:30 +0200 Subject: [PATCH] refactor initial prompt generation --- evoprompt/optimization.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py index 0eecb0c..7b14eb7 100644 --- a/evoprompt/optimization.py +++ b/evoprompt/optimization.py @@ -94,6 +94,21 @@ class PromptOptimization: ) return self.task.evaluate_validation(prompt, parent_histories) + def get_initial_prompts(self, num_initial_prompts: int): + paraphrases, paraphrase_usage = paraphrase_prompts( + self.evolution_model, + self.task.base_prompt, + n=num_initial_prompts - 1, + unique_paraphrases=True, + ) + self.total_evolution_usage += paraphrase_usage + logger.info( + "Paraphrased prompt '%s': %s.", + self.task.base_prompt.replace("\r", "\\r").replace("\n", "\\n"), + paraphrases, + ) + return [self.task.base_prompt] + paraphrases + def add_prompt( self, prompt: str, @@ -159,21 +174,8 @@ class PromptOptimization: self.run_directory = initialize_run_directory(self.evolution_model) self.save_snapshot() - paraphrases, paraphrase_usage = paraphrase_prompts( - self.evolution_model, - self.task.base_prompt, - n=num_initial_prompts - 1, - unique_paraphrases=True, - ) - self.total_evolution_usage += paraphrase_usage - logger.info( - "Paraphrased prompt '%s': %s.", - self.task.base_prompt.replace("\r", "\\r").replace("\n", "\\n"), - paraphrases, - ) - # the initial prompts - initial_prompts = [self.task.base_prompt] + paraphrases + initial_prompts = self.get_initial_prompts(num_initial_prompts) initial_prompts = self.add_prompts( initial_prompts, metas=[{"gen": 0} for _ in initial_prompts] ) -- GitLab