From 9f5e25811e79f41488056a16da673230a5464961 Mon Sep 17 00:00:00 2001 From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de> Date: Tue, 1 Oct 2024 16:14:53 +0200 Subject: [PATCH] Update sample selection for base prompt generation and temporarily disable base prompt generation --- evoprompt/task/base_prompts_mixin.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/evoprompt/task/base_prompts_mixin.py b/evoprompt/task/base_prompts_mixin.py index 5c31c50..c4d00cd 100644 --- a/evoprompt/task/base_prompts_mixin.py +++ b/evoprompt/task/base_prompts_mixin.py @@ -5,6 +5,7 @@ import re from datasets import Dataset from evoprompt.models import LLMModel +from evoprompt.utils import get_rng class BasePromptsFromJsonMixin: @@ -35,17 +36,25 @@ class BasePromptsFromGeneration: self, num_prompts: int, patience: int = 10, allow_duplicates: bool = False ) -> str: self.validation_dataset: Dataset - samples = self.validation_dataset._select_contiguous(0, 5) + samples = self.validation_dataset.shuffle(42).select( + get_rng().choice(len(self.validation_dataset), 5, replace=False) + ) prompt = "I gave a friend an instruction and five inputs. The friend read the instruction and wrote an output for every one of the inputs. Here are the input-output pairs:\n" + raise NotImplementedError( + "The prompt needs to be adapted for the model taking into account the correct format." + ) prompt = self.build_demonstration_prompt(samples, prompt=prompt) prompt += "\nThe instruction was " + system_message = "You are a helpful assistant. Please provide the instruction wrapped within tags <instruction> and </instruction> that belongs to the given input-output pairs." + input(prompt) generated_prompts = [] while len(generated_prompts) < num_prompts: response, _, _, _ = self.evolution_model.create_completion( - system_message=f"You are a helpful assistant. 
Please provide the instruction wrapped within tags <instruction> and </instruction> that belongs to the given input-output pairs.", + system_message=system_message, prompt=prompt, ) + input(response) matches = re.findall( # regex that extracts anything within tags <instruction> and optional </instruction> rf"<instruction>(.+?)(?:(?=</instruction>)|$)", -- GitLab