diff --git a/evoprompt/task/base_prompts_mixin.py b/evoprompt/task/base_prompts_mixin.py
index 5c31c506d7543f7e6d4a5cde34bc6605a5e20abd..c4d00cd99d71788623ce4b6d1e67e5b5dfd14e60 100644
--- a/evoprompt/task/base_prompts_mixin.py
+++ b/evoprompt/task/base_prompts_mixin.py
@@ -5,6 +5,7 @@ import re
 from datasets import Dataset
 
 from evoprompt.models import LLMModel
+from evoprompt.utils import get_rng
 
 
 class BasePromptsFromJsonMixin:
@@ -35,17 +36,25 @@ class BasePromptsFromGeneration:
         self, num_prompts: int, patience: int = 10, allow_duplicates: bool = False
     ) -> str:
         self.validation_dataset: Dataset
-        samples = self.validation_dataset._select_contiguous(0, 5)
+        samples = self.validation_dataset.shuffle(42).select(
+            get_rng().choice(len(self.validation_dataset), 5, replace=False)
+        )
         prompt = "I gave a friend an instruction and five inputs. The friend read the instruction and wrote an output for every one of the inputs. Here are the input-output pairs:\n"
+        raise NotImplementedError(
+            "The prompt needs to be adapted for the model taking into account the correct format."
+        )
         prompt = self.build_demonstration_prompt(samples, prompt=prompt)
         prompt += "\nThe instruction was "
+        system_message = "You are a helpful assistant. Please provide the instruction wrapped within tags <instruction> and </instruction> that belongs to the given input-output pairs."
+        input(prompt)
 
         generated_prompts = []
         while len(generated_prompts) < num_prompts:
             response, _, _, _ = self.evolution_model.create_completion(
-                system_message=f"You are a helpful assistant. Please provide the instruction wrapped within tags <instruction> and </instruction> that belongs to the given input-output pairs.",
+                system_message=system_message,
                 prompt=prompt,
             )
+            input(response)
             matches = re.findall(
                 # regex that extracts anything within tags <instruction> and optional </instruction>
                 rf"<instruction>(.+?)(?:(?=</instruction>)|$)",