From 2f2fd458439d393883d645804677ac3b58c2f2d1 Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Mon, 29 Jul 2024 14:42:27 +0200
Subject: [PATCH] Fix unique paraphrase generation

---
 evoprompt/optimization.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index 59b564a..7100a7a 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -43,9 +43,19 @@ def paraphrase_prompts(
         total_usage += usage
         if "<prompt>" in paraphrase:
             paraphrase = paraphrase.split("<prompt>")[1].split("</prompt>")[0]
-        if not unique_paraphrases or paraphrase not in paraphrases:
+        if (
+            not unique_paraphrases
+            or paraphrase not in paraphrases
+            or max_tries - num_tries == n - len(paraphrases)
+        ):
             # add paraphrase only if not already present if unique_paraphrases==True
             paraphrases.append(paraphrase)
+
+    assert len(paraphrases) == n, "Requested %d paraphrases, but %d were generated." % (
+        n,
+        len(paraphrases),
+    )
+
     if return_only_unique_paraphrases:
         paraphrases = list(set(paraphrases))
     return paraphrases, total_usage
@@ -156,7 +166,11 @@ class PromptOptimization:
             unique_paraphrases=True,
         )
         self.total_evolution_usage += paraphrase_usage
-        logger.info("Paraphrased prompt '%s': %s.", self.task.base_prompt.replace("\r", "\\r").replace("\n", "\\n"), paraphrases)
+        logger.info(
+            "Paraphrased prompt '%s': %s.",
+            self.task.base_prompt.replace("\r", "\\r").replace("\n", "\\n"),
+            paraphrases,
+        )
 
         # the initial prompts
         initial_prompts = [self.task.base_prompt] + paraphrases
-- 
GitLab