From 9f5e25811e79f41488056a16da673230a5464961 Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Tue, 1 Oct 2024 16:14:53 +0200
Subject: [PATCH] Update sample selection for base prompt generation and
 temporarily disable base prompt generation

---
 evoprompt/task/base_prompts_mixin.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/evoprompt/task/base_prompts_mixin.py b/evoprompt/task/base_prompts_mixin.py
index 5c31c50..c4d00cd 100644
--- a/evoprompt/task/base_prompts_mixin.py
+++ b/evoprompt/task/base_prompts_mixin.py
@@ -5,6 +5,7 @@ import re
 from datasets import Dataset
 
 from evoprompt.models import LLMModel
+from evoprompt.utils import get_rng
 
 
 class BasePromptsFromJsonMixin:
@@ -35,17 +36,25 @@ class BasePromptsFromGeneration:
         self, num_prompts: int, patience: int = 10, allow_duplicates: bool = False
     ) -> str:
         self.validation_dataset: Dataset
-        samples = self.validation_dataset._select_contiguous(0, 5)
+        samples = self.validation_dataset.shuffle(42).select(
+            get_rng().choice(len(self.validation_dataset), 5, replace=False)
+        )
         prompt = "I gave a friend an instruction and five inputs. The friend read the instruction and wrote an output for every one of the inputs. Here are the input-output pairs:\n"
+        raise NotImplementedError(
+            "The prompt needs to be adapted for the model taking into account the correct format."
+        )
         prompt = self.build_demonstration_prompt(samples, prompt=prompt)
         prompt += "\nThe instruction was "
+        system_message = "You are a helpful assistant. Please provide the instruction wrapped within tags <instruction> and </instruction> that belongs to the given input-output pairs."
+        input(prompt)
 
         generated_prompts = []
         while len(generated_prompts) < num_prompts:
             response, _, _, _ = self.evolution_model.create_completion(
-                system_message=f"You are a helpful assistant. Please provide the instruction wrapped within tags <instruction> and </instruction> that belongs to the given input-output pairs.",
+                system_message=system_message,
                 prompt=prompt,
             )
+            input(response)
             matches = re.findall(
                 # regex that extracts anything within tags <instruction> and optional </instruction>
                 rf"<instruction>(.+?)(?:(?=</instruction>)|$)",
-- 
GitLab