from functools import lru_cache, partial
from pathlib import Path
from collections import defaultdict
from datasets import Dataset, load_dataset
from dotenv import load_dotenv
from numpy.random import choice
from tqdm import tqdm, trange
from cli import argument_parser
from models import Llama2, OpenAI
from utils import (
log_calls,
logger,
save_family_tree_visualization,
save_genealogy,
save_snapshot,
)
# whether to use chat model for LLM or not
USE_CHAT: bool = False
load_dotenv()
current_directory = Path(__file__).resolve().parent
CLASSIFICATION_PROMPT = """
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{instruction}.
### Input:
{input}
### Response:
"""
PARAPHRASE_PROMPT = """
Below is an instruction that describes a task. Write a response that paraphrases the instruction. Only output the paraphrased instruction bracketed in <prompt> and </prompt>.
### Instruction:
{instruction}
### Response:
<prompt>
"""
@log_calls("Evaluating dataset")
def evaluate_prompt(prompt: str, dataset: Dataset):
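    """Evaluate `prompt` on the SST-2 examples in `dataset` and return the classification accuracy."""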
sst2_labels = {"negative": 0, "positive": 1}
    results: defaultdict[str, int] = defaultdict(int)
dataset_iterator = tqdm(dataset, desc="evaluating prompt", leave=False)
for datum in dataset_iterator:
response = evaluation_model(
prompt=CLASSIFICATION_PROMPT.format(instruction=prompt, input=datum["text"])
)
answer = response.lower()
answer_label = None
for label in sst2_labels.keys():
if label in answer:
answer_label = sst2_labels[label]
break
else:
logger.warning(f"Invalid answer: {answer}")
results["failed"] += 1
continue
classification_result = (
"incorrect" if answer_label != datum["label"] else "correct"
)
results[classification_result] += 1
dataset_iterator.set_postfix(results)
accuracy = results["correct"] / sum(results.values())
return accuracy
@log_calls("Paraphrasing prompts")
def paraphrase_prompts(prompt: str, n: int):
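    """Ask the evolution model to produce `n` paraphrases of `prompt`."""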
paraphrases = []
for _ in range(n):
paraphrase = evolution_model(
prompt=PARAPHRASE_PROMPT.format(instruction=prompt)
)
paraphrases.append(paraphrase)
return paraphrases
@log_calls("Performing selection")
def selection(prompts, scores):
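    """Select two parent prompts via roulette-wheel selection, with probabilities proportional to their scores."""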
# In GA, two parent solutions are normally selected based on the roulette wheel
# selection method according to the fitness value (Lipowski & Lipowska, 2012).
# Similar to this, we utilize the roulette wheel selection method to select
# two parent prompts in the current population according to the scores evaluated
# on development sets. Specifically, let si denote the performance score on the
# development set of the i-th prompt in the population, which contains a total
# of N prompts. The probability of selecting the i-th prompt as a parent can be expressed as
    # p_i = s_i / Σ_{j=1}^{N} s_j.
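    # Example: scores [0.2, 0.3, 0.5] yield selection probabilities [0.2, 0.3, 0.5].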
if sum(scores) == 0:
# sum of scores is 0 ==> each score is 0, draw with equal probability
selection_probabilities = len(scores) * [1 / len(scores)]
else:
selection_probabilities = [score / sum(scores) for score in scores]
return choice(prompts, size=2, replace=False, p=selection_probabilities)
GA_PROMPT = """
Please follow the instruction step-by-step to generate a better prompt.
1. Cross over the following prompts and generate a new prompt:
Prompt 1: {prompt1}
Prompt 2: {prompt2}
2. Mutate the prompt generated in Step 1 and generate a final prompt bracketed with <prompt> and </prompt>.
"""
DE_PROMPT = """
Please follow the instruction step-by-step to generate a better prompt.
1. Identify the different parts between the Prompt 1 and Prompt 2:
Prompt 1: {prompt1}
Prompt 2: {prompt2}
2. Randomly mutate the different parts
3. Combine the different parts with Prompt 3, selectively replace it with the different parts in Step 2 and generate a new prompt.
Prompt 3: {prompt3}
4. Cross over the prompt in the Step 3 with the following basic prompt and generate a final prompt bracketed with <prompt> and </prompt>:
Basic Prompt: {basic_prompt}
"""
@log_calls("Performing prompt evolution using GA")
def evolution_ga(prompt1: str, prompt2: str):
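    """GA operator: cross over the two parent prompts and mutate the result via the evolution LLM."""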
# Following the evolutionary operators in GA, a new candidate prompt is generated through
# a two-step process based on the selected two parents:
# 1) The parent prompts undergo crossover, resulting in a new prompt that
# selectively combines components from both parents;
# 2) The newly generated prompt from the first step undergoes mutation,
# in which random alterations are made to some of its content.
# Based on this two-step process, we design instructions, guiding LLMs to
# generate a new prompt based on these steps to perform Evo(·) in Algorithm 1.
evolved_prompt = evolution_model(
prompt=GA_PROMPT.format(prompt1=prompt1, prompt2=prompt2)
)
if "<prompt>" in evolved_prompt:
evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
return evolved_prompt
@log_calls("Performing prompt evolution using DE")
def evolution_de(prompt1: str, prompt2: str, basic_prompt: str, best_prompt: str):
    # DE-style evolution: (1) identify the parts that differ between the two parent
    # prompts, (2) randomly mutate those differing parts, (3) combine the mutated
    # parts with the current best prompt, and (4) cross the result over with the
    # basic prompt to produce the final prompt, bracketed in <prompt> tags.
evolved_prompt = evolution_model(
prompt=DE_PROMPT.format(
prompt1=prompt1,
prompt2=prompt2,
prompt3=best_prompt,
basic_prompt=basic_prompt,
)
)
if "<prompt>" in evolved_prompt:
evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
return evolved_prompt
@log_calls("Updating prompts")
def update(prompts: list[str], scores: list[float], N: int):
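    """Retain the N highest-scoring prompts and their scores from the given candidates."""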
# EVOPROMPT iteratively generates new candidate prompts and assesses each prompt
# using a development set, denoted as D, to obtain a score that quantifies the
# quality of the prompt. We consider a straightforward selection strategy.
# Specifically, at each iteration, EVOPROMPT based on GA produces N new prompts,
# which are combined with the current population of N prompts.
# The updated population is then selected by retaining the N prompts with the highest scores.
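    # e.g. update(["a", "b", "c"], [0.1, 0.9, 0.5], N=2) keeps ["c", "b"] with scores [0.5, 0.9]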
retained_prompts = []
retained_scores = []
for prompt, score in zip(prompts, scores):
if len(retained_prompts) < N:
retained_prompts.append(prompt)
retained_scores.append(score)
elif score > min(retained_scores):
min_index = retained_scores.index(min(retained_scores))
retained_prompts[min_index] = prompt
retained_scores[min_index] = score
return retained_prompts, retained_scores
def run_episode(evo_alg_str: str):
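    """Run one EvoPrompt episode (Algorithm 1) on SST-2, using "ga" or "de" as the evolutionary operator."""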
# Algorithm 1 Discrete prompt optimization: EVOPROMPT
# Require:
# - Size of population
N = 10
# - Initial prompts P0 = {p1, p2, . . . , pN }
    sst2_base_prompt = """In this task, you are given sentences from movie reviews. The task is to classify a sentence as "positive" if the sentiment of the sentence is positive or as "negative" if the sentiment of the sentence is negative. Return label only without any other text."""  # from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning
paraphrases = paraphrase_prompts(sst2_base_prompt, n=N - 1)
# the current population
population = [sst2_base_prompt] + paraphrases
    # P keeps track of the prompt population at each generation
P = [population]
# - A dev set D
# The size of the development set is 200.
D = load_dataset("SetFit/sst2", split="validation[:200]")
# - fD(·) denotes the score of a prompt on the desired LLM evaluated on D
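    # lru_cache memoizes prompt scores, so prompts that survive into later generations are not re-evaluated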
f_D = lru_cache(maxsize=None)(partial(evaluate_prompt, dataset=D))
# - a pre-defined number of iterations T
T = 10
# - carefully designed evolutionary operators to generate a new prompt Evo(·)
# Line 1: Initial evaluation scores: S0 ← {si = fD (pi )|i ∈ [1, N ]}
# the current population's scores
population_scores = [f_D(p) for p in P[0]]
# S keeps track of scores
S = [population_scores]
# add initial prompts to family tree
for prompt, score in zip(P[0], S[0]):
# None marks that there is no parent
family_tree[prompt] = None
# evolution = EvolutionGA(num_evolutions=N)
# Line 2:
for t in trange(1, T + 1, desc="T", leave=True):
# Line 3: Selection: select a certain number of prompts from current population as parent prompts
# pr1,...,prk ∼ Pt−1
if evo_alg_str == "de":
# DE needs best prompt for evolution
            best_prompt_current_evolution = population[
                max(range(N), key=lambda i: population_scores[i])
            ]
# start new generation
P.append([])
S.append([])
for i in trange(N, desc="N", leave=False):
# for both GA and DE we start with two parent prompts
pr1, pr2 = selection(population, population_scores)
# Line 4: Evolution: generate a new prompt based on the selected parent prompts by leveraging LLM to perform evolutionary operators
# p′i ←Evo(pr1,...,prk)
if evo_alg_str == "ga":
p_i = evolution_ga(pr1, pr2)
elif evo_alg_str == "de":
p_i = evolution_de(
pr1, pr2, population[i], best_prompt_current_evolution
)
# Line 5: Evaluation
# s′_i ← f(p′i,D)
s_i = f_D(p_i)
P[t].append(p_i)
S[t].append(s_i)
# keep track of genealogy with score
family_tree[p_i] = (pr1, pr2)
# Line 6: Update based on the evaluation scores
# Pt ← {Pt−1, p′i} and St ← {St−1, s′i}
if evo_alg_str == "ga":
# GA keeps N best prompts from current population and evolutions
population, population_scores = update(
population + P[t], population_scores + S[t], N
)
elif evo_alg_str == "de":
# for DE we keep the evolved prompt if it is better than the basic prompt, and use the basic prompt otherwise
population, population_scores = list(
zip(
*[
(new_prompt, new_prompt_score)
if new_prompt_score > current_prompt_score
else (current_prompt, current_prompt_score)
for current_prompt, current_prompt_score, new_prompt, new_prompt_score in zip(
population, population_scores, P[t], S[t]
)
]
)
)
save_snapshot(family_tree, P, S, T, N)
save_genealogy(family_tree, P, S, T)
save_family_tree_visualization(family_tree, P, S, T + 1)
# Line 8: Return the best prompt, p∗, among the final population PT :
# p∗ ← argmaxp∈PT f(p, D)
p = max(range(N), key=lambda i: S[T][i])
logger.info(f"Best prompt: {P[T][p]}")
    # We pick the prompt with the highest score on the development set and report its score on the test set.
    test_D = load_dataset("SetFit/sst2", split="test")
    test_accuracy = evaluate_prompt(P[T][p], test_D)
    logger.info(f"Test accuracy of best prompt: {test_accuracy}")
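# genealogy of prompts: maps each prompt to the tuple of parent prompts it was evolved from (None for initial prompts)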
family_tree = {}
if __name__ == "__main__":
options = argument_parser.parse_args()
# set up evolution model
match options.evolution_engine:
case "llama2":
logger.info("Using Llama2 client as the evolution engine")
evolution_model = Llama2(
str(current_directory / "models/llama-2-13b-chat.Q5_K_M.gguf"),
chat=USE_CHAT,
)
case "openai":
logger.info("Using OpenAI client as the evolution engine")
evolution_model = OpenAI("gpt-3.5-turbo", chat=USE_CHAT)
# set up evaluation model
    # NOTE currently we always stick to Llama2 as the evaluation model
match options.evolution_engine:
case "llama2":
evaluation_model = evolution_model
case "openai":
evaluation_model = Llama2(
str(current_directory / "models/llama-2-13b-chat.Q5_K_M.gguf"),
chat=USE_CHAT,
)
run_episode(evo_alg_str=options.evolution_algorithm)