From 2f2fd458439d393883d645804677ac3b58c2f2d1 Mon Sep 17 00:00:00 2001 From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de> Date: Mon, 29 Jul 2024 14:42:27 +0200 Subject: [PATCH] Fix unique paraphrase generation --- evoprompt/optimization.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py index 59b564a..7100a7a 100644 --- a/evoprompt/optimization.py +++ b/evoprompt/optimization.py @@ -43,9 +43,19 @@ def paraphrase_prompts( total_usage += usage if "<prompt>" in paraphrase: paraphrase = paraphrase.split("<prompt>")[1].split("</prompt>")[0] - if not unique_paraphrases or paraphrase not in paraphrases: + if ( + not unique_paraphrases + or paraphrase not in paraphrases + or max_tries - num_tries == n - len(paraphrases) + ): # add paraphrase only if not already present if unique_paraphrases==True paraphrases.append(paraphrase) + + assert len(paraphrases) == n, "Requested %d paraphrases, but %d were generated." % ( + n, + len(paraphrases), + ) + if return_only_unique_paraphrases: paraphrases = list(set(paraphrases)) return paraphrases, total_usage @@ -156,7 +166,11 @@ class PromptOptimization: unique_paraphrases=True, ) self.total_evolution_usage += paraphrase_usage - logger.info("Paraphrased prompt '%s': %s.", self.task.base_prompt.replace("\r", "\\r").replace("\n", "\\n"), paraphrases) + logger.info( + "Paraphrased prompt '%s': %s.", + self.task.base_prompt.replace("\r", "\\r").replace("\n", "\\n"), + paraphrases, + ) # the initial prompts initial_prompts = [self.task.base_prompt] + paraphrases -- GitLab