
Use correct format for demonstration samples for evaluation and evolution

Merged: Max Kimmich requested to merge refactor-models into master
11 files changed: +92 −57
@@ -10,7 +10,7 @@ from evoprompt.utils import get_rng
 class BasePromptsFromJsonMixin:
     @staticmethod
-    def _load_json_file(path: str):
+    def _load_json_file(path: str) -> list[str]:
         with Path(path).open() as json_file:
             return json.load(json_file)
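For context (not part of the diff): the new list[str] return annotation implies each entry in `base_prompts_files` is a plain JSON array of prompt strings. A minimal sketch of that round-trip; the file name and prompt texts below are made-up examples:

import json
from pathlib import Path

# Hypothetical base-prompts file: a plain JSON array of prompt strings.
path = Path("base_prompts_example.json")
path.write_text(json.dumps([
    "Classify the sentiment of the given sentence.",
    "Decide whether the review is positive or negative.",
]))

with path.open() as json_file:
    prompts: list[str] = json.load(json_file)  # matches the new annotation

print(prompts)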
@@ -20,41 +20,48 @@ class BasePromptsFromJsonMixin:
             raise Exception(
                 f"Class {self.__class__} does not exhibit attribute `base_prompts_files` which is needed for `BasePromptsFromJsonMixin`."
             )
-        base_prompts = []
+        prompts, sources = super().base_prompts
+        prompts_from_files = []
         for prompt_file in self.base_prompts_files:
-            base_prompts += self._load_json_file(prompt_file)
-        return base_prompts
+            prompts_from_files += self._load_json_file(prompt_file)
+        prompts += prompts_from_files
+        sources += ["baseprompt_file"] * len(prompts_from_files)
+        return prompts, sources


-class BasePromptsFromGeneration:
+class BasePromptsFromGenerationMixin:
     def __init__(self, *args, **kwargs) -> None:
         self.evolution_model: LLMModel = kwargs.get("evolution_model")
         super().__init__(*args, **kwargs)

     # this implements the initial population generation from Zhou et al., 2023: Large Language Models are Human-Level Prompt Engineers
     # patience allows to stop the generation process if no new prompts can be generated
     # can be set to -1 to generate as many prompts as needed (but can possibly run forever)
     def generate_prompt(
         self, num_prompts: int, patience: int = 10, allow_duplicates: bool = False
-    ) -> str:
+    ) -> list[str]:
         self.validation_dataset: Dataset
         samples = self.validation_dataset.shuffle(42).select(
             get_rng().choice(len(self.validation_dataset), 5, replace=False)
         )
-        prompt = "I gave a friend an instruction and five inputs. The friend read the instruction and wrote an output for every one of the inputs. Here are the input-output pairs:\n"
-        raise NotImplementedError(
-            "The prompt needs to be adapted for the model taking into account the correct format."
-        )
+        prompt = "I gave a friend a single instruction and five inputs. The friend read the instruction and wrote an output for every one of the inputs. Here are the input-output pairs:\n\n"
-        prompt += "\n".join(
-            f"Input:\n{self._get_prompt_text_for_datum(sample)}\nOutput:\n{self._get_gold_label_generation_for_datum(sample)}\n"
-            for sample in samples
-        )
+        prompt = self.build_demonstration_prompt(samples, prompt=prompt)
         prompt += "\nThe instruction was "
         system_message = "You are a helpful assistant. Please provide the instruction wrapped within tags <instruction> and </instruction> that belongs to the given input-output pairs."
+        input(prompt)
+        messages = [
+            self.evolution_model._get_user_message(prompt)
+        ]  # , self.evolution_model._get_assistant_message("The instruction was ")]

         generated_prompts = []
         while len(generated_prompts) < num_prompts:
             response, _, _, _ = self.evolution_model.create_completion(
                 system_message=system_message,
-                prompt=prompt,
+                messages=messages,
                 use_randomness=True,
             )
+            input(response)
             matches = re.findall(
                 # regex that extracts anything within tags <instruction> and optional </instruction>
                 rf"<instruction>(.+?)(?:(?=</instruction>)|$)",
@@ -62,9 +69,9 @@ class BasePromptsFromGeneration:
                 flags=re.IGNORECASE,
             )
             if matches:
-                prompt = matches[-1].strip()
-                if allow_duplicates or prompt not in generated_prompts:
-                    generated_prompts.append(matches[-1].strip())
+                generated_prompt = matches[-1].strip()
+                if allow_duplicates or generated_prompt not in generated_prompts:
+                    generated_prompts.append(generated_prompt)
             else:
                 if patience == 0:
                     break
@@ -74,6 +81,15 @@ class BasePromptsFromGeneration:
     @property
     def base_prompts(self):
-        num_prompts = getattr(self, "num_generated_base_prompts", 0)
+        if not hasattr(self, "num_generated_base_prompts"):
+            raise AttributeError(
+                f"{self.__class__} must expose attribute `num_generated_base_prompts`"
+            )
+        prompts, sources = super().base_prompts
+        num_prompts = self.num_generated_base_prompts
+        generated_prompts = self.generate_prompt(num_prompts)
+        prompts += generated_prompts
+        sources += ["baseprompt_gen"] * len(generated_prompts)
-        return self.generate_prompt(num_prompts)
+        return prompts, sources
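To illustrate the pattern this MR converges on (a sketch, not code from the repository): both mixins now implement `base_prompts` as a cooperative property that extends whatever `super().base_prompts` returns and tags each prompt with a source label ("baseprompt_file" or "baseprompt_gen"). Apart from those tags, the class names and placeholder prompts below are hypothetical stand-ins:

# Minimal sketch of the cooperative `base_prompts` chain; only the source tags
# ("baseprompt_file", "baseprompt_gen") are taken from the diff, everything else
# is a hypothetical stand-in.

class PromptSourceBase:
    @property
    def base_prompts(self) -> tuple[list[str], list[str]]:
        # end of the chain: no prompts, no sources
        return [], []


class FilePromptsMixin(PromptSourceBase):
    # stands in for BasePromptsFromJsonMixin
    @property
    def base_prompts(self) -> tuple[list[str], list[str]]:
        prompts, sources = super().base_prompts
        prompts_from_files = ["prompt loaded from a JSON file"]  # placeholder for _load_json_file
        prompts += prompts_from_files
        sources += ["baseprompt_file"] * len(prompts_from_files)
        return prompts, sources


class GeneratedPromptsMixin(PromptSourceBase):
    # stands in for BasePromptsFromGenerationMixin
    num_generated_base_prompts = 2

    @property
    def base_prompts(self) -> tuple[list[str], list[str]]:
        prompts, sources = super().base_prompts
        generated = [f"generated prompt {i}" for i in range(self.num_generated_base_prompts)]
        prompts += generated
        sources += ["baseprompt_gen"] * len(generated)
        return prompts, sources


class Task(FilePromptsMixin, GeneratedPromptsMixin):
    # MRO: Task -> FilePromptsMixin -> GeneratedPromptsMixin -> PromptSourceBase
    pass


prompts, sources = Task().base_prompts
for prompt, source in zip(prompts, sources):
    print(source, "->", prompt)
# baseprompt_gen -> generated prompt 0
# baseprompt_gen -> generated prompt 1
# baseprompt_file -> prompt loaded from a JSON file

Keeping the prompts and sources lists aligned is presumably what lets downstream evaluation and evolution code attribute each base prompt to its origin (file vs. generation).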