From 37d991ad49fae1026753cc7794e8e628fd3f9921 Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Tue, 4 Jun 2024 18:20:14 +0200
Subject: [PATCH] Allow generating unique paraphrases

---
 optimization.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/optimization.py b/optimization.py
index 1cb0975..8adcafc 100644
--- a/optimization.py
+++ b/optimization.py
@@ -18,14 +18,17 @@ def paraphrase_prompts(
     model: LLMModel,
     prompt: str,
     n: int,
-    unique_prompts: bool = False,
-    num_tries: int = 10,
-    return_only_unique_prompts: bool = False,
+    unique_paraphrases: bool = False,
+    max_tries: int = 10,
+    return_only_unique_paraphrases: bool = False,
 ):
-    # TODO implement unique paraphrases
     total_usage = ModelUsage()
     paraphrases = []
-    for _ in range(n):
+    num_tries = 0
+    while len(paraphrases) < n:
+        if num_tries >= max_tries:
+            break
+        num_tries += 1
         paraphrase, usage = model(
             system_message=PARAPHRASE_PROMPT,
             prompt=prompt,
@@ -35,10 +38,12 @@ def paraphrase_prompts(
         total_usage += usage
         if "<prompt>" in paraphrase:
             paraphrase = paraphrase.split("<prompt>")[1].split("</prompt>")[0]
-        paraphrases.append(paraphrase)
-    if return_only_unique_prompts:
+        if not unique_paraphrases or paraphrase not in paraphrases:
+            # when unique_paraphrases is True, skip paraphrases that are already present
+            paraphrases.append(paraphrase)
+    if return_only_unique_paraphrases:
         paraphrases = list(set(paraphrases))
-    return paraphrases, usage
+    return paraphrases, total_usage
 
 
 class PromptOptimization:
-- 
GitLab