From fc4f0be620b8df7ad568793c9a6ee99f07b556f6 Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Tue, 3 Sep 2024 17:17:37 +0200
Subject: [PATCH] Implement generating base prompts

---
 evoprompt/helpers/prompts.py         | 49 +++++++++++++++++++++++++++-
 evoprompt/task/question_answering.py | 18 +++++-----
 main.py                              | 15 ++++++++-
 3 files changed, 72 insertions(+), 10 deletions(-)

diff --git a/evoprompt/helpers/prompts.py b/evoprompt/helpers/prompts.py
index df1c725..b73ed56 100644
--- a/evoprompt/helpers/prompts.py
+++ b/evoprompt/helpers/prompts.py
@@ -1,5 +1,10 @@
 import json
 from pathlib import Path
+import re
+
+from datasets import Dataset
+
+from evoprompt.models import LLMModel
 
 
 class BasePromptsFromJsonMixin:
@@ -21,4 +26,46 @@ class BasePromptsFromJsonMixin:
 
 
 class BasePromptsFromGeneration:
-    pass
+    def __init__(self, *args, **kwargs) -> None:
+        self.evolution_model: LLMModel = kwargs.get("evolution_model")
+        super().__init__(*args, **kwargs)
+
+    # this implements the initial population generation from Zhou et al., 2023: Large Language Models are Human-Level Prompt Engineers
+    def generate_prompt(
+        self, num_prompts: int, patience: int = 10, allow_duplicates: bool = False
+    ) -> list[str]:
+        self.validation_dataset: Dataset
+        samples = self.validation_dataset.select(range(5))
+        prompt = "I gave a friend an instruction and five inputs. The friend read the instruction and wrote an output for every one of the inputs. Here are the input-output pairs:\n"
+        for sample in samples:
+            prompt += f"\n\n{self._get_prompt_text_for_datum(sample)}\n{self._get_generation_prefix()}{self._get_gold_label_generation_for_datum(sample)}\n"
+        prompt += "\nThe instruction was "
+
+        generated_prompts = []
+        while len(generated_prompts) < num_prompts:
+            response, _, _ = self.evolution_model.create_completion(
+                system_message="You are a helpful assistant. Please provide the instruction wrapped within tags <instruction> and </instruction> that belongs to the given input-output pairs.",
+                prompt=prompt,
+            )
+            matches = re.findall(
+                # regex that extracts anything within tags <instruction> and optional </instruction>
+                r"<instruction>(.+?)(?:(?=</instruction>)|$)",
+                response,
+                flags=re.IGNORECASE,
+            )
+            if matches:
+                generated_prompt = matches[-1].strip()
+                if allow_duplicates or generated_prompt not in generated_prompts:
+                    generated_prompts.append(generated_prompt)
+            else:
+                if patience == 0:
+                    break
+                if patience > 0:
+                    patience -= 1
+        return generated_prompts
+
+    @property
+    def base_prompts(self):
+        num_prompts = getattr(self, "num_generated_base_prompts", 0)
+
+        return self.generate_prompt(num_prompts)
diff --git a/evoprompt/task/question_answering.py b/evoprompt/task/question_answering.py
index 9cfa194..0634b3e 100644
--- a/evoprompt/task/question_answering.py
+++ b/evoprompt/task/question_answering.py
@@ -8,6 +8,7 @@ from datasets import Dataset
 from evaluate import load as load_metric
 from llama_cpp import LlamaGrammar
 
+from evoprompt.helpers.prompts import BasePromptsFromGeneration
 from evoprompt.opt_types import ModelUsage
 from evoprompt.task.task import DatasetDatum, Task
 from evoprompt.utils import get_rng
@@ -163,16 +164,10 @@ class QuestionAnswering(Task):
     def metric_name(self):
         return "f1"
 
-    @property
-    def base_prompts(self):
-        # TODO find good base prompts
-        return [
-            """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
-        ]
-
 
-class SQuAD(QuestionAnswering):
+class SQuAD(BasePromptsFromGeneration, QuestionAnswering):
     shorthand = "squad"
+    num_generated_base_prompts = 10
 
     def load_validation_set(
         self, validation_dataset: str | None, validation_split: str | None
@@ -196,3 +191,10 @@ class SQuAD(QuestionAnswering):
 
     def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
         return self._get_gold_label_for_datum(datum)["text"][0]
+
+    @property
+    def base_prompts(self):
+        generated_base_prompts = super().base_prompts
+        return [
+            "In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."
+        ] + generated_base_prompts
diff --git a/main.py b/main.py
index 3b87b48..a7a0ca2 100644
--- a/main.py
+++ b/main.py
@@ -3,6 +3,9 @@ import os
 from typing import Any
 
 from dotenv import load_dotenv
+import wandb
+import weave
+from weave.trace.settings import UserSettings
 
 from evoprompt.cli import argument_parser
 from evoprompt.evolution import get_optimizer_class
@@ -49,6 +52,11 @@ if __name__ == "__main__":
             "Judge engine cannot be 'llamachat' when evolution engine is 'llama'"
         )
 
+    # init weave tracing and do not print call link
+    weave_settings = UserSettings(disabled=False, print_call_link=False)
+    weave.init(project_name="evoprompt", settings=weave_settings)
+    wandb.init(project="evoprompt")
+
     # set up console logging and rnd
     setup_console_logger(verbosity_level=options.verbose)
     init_rng(options.seed)
@@ -94,7 +102,12 @@ if __name__ == "__main__":
     evaluation_model = LLMModel.get_model(name=evaluation_model_name, options=options)
     logger.info(f"Using {evaluation_model_name} as the evaluation engine")
 
-    task = get_task(options.task, evaluation_model, **options.__dict__)
+    task = get_task(
+        options.task,
+        evaluation_model,
+        **options.__dict__,
+        evolution_model=evolution_model,
+    )
     logger.info(f"Running with task {task.__class__.__name__}")
 
     logger.info("Using evolutionary algorithm '%s'", options.evolution_algorithm)
-- 
GitLab