Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
main.py 12.26 KiB
from collections import defaultdict
from functools import lru_cache, partial
from pathlib import Path
from typing import DefaultDict, get_type_hints

from datasets import Dataset, load_dataset
from dotenv import load_dotenv
from llama_cpp import Callable, Llama
from numpy.random import choice
from tqdm import tqdm, trange

from cli import argument_parser
from models import Llama2, OpenAI
from utils import (
    log_calls,
    logger,
    save_family_tree_visualization,
    save_genealogy,
    save_snapshot,
)
# Directory holding this script; the local model weights are looked up
# relative to it (the `models/...` paths used when constructing Llama2).
current_directory = Path(__file__).resolve().parents[0]

# Selects the chat interface (True) or plain completion (False) of the LLM
# clients; forwarded as `chat=` when Llama2 / OpenAI are constructed.
USE_CHAT: bool = False

# Read variables from a local .env file into the process environment.
load_dotenv()

# Template used by evaluate_prompt to ask the evaluation model to classify one
# datum: `{instruction}` receives the candidate prompt, `{input}` the datum text.
# NOTE(review): a "." is appended after {instruction}; instructions that already
# end with a period (e.g. the SST-2 base prompt in run_episode) render as "..".
# Left unchanged because editing it alters the exact evaluation input.
CLASSIFICATION_PROMPT = """
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}.

### Input:
{input}

### Response:
"""

# Template used by paraphrase_prompts to build the initial population.
# It ends with an opening <prompt> tag, so the model's continuation is
# presumably the paraphrase followed by a closing </prompt> tag.
PARAPHRASE_PROMPT = """
Below is an instruction that describes a task. Write a response that paraphrases the instruction. Only output the paraphrased instruction bracketed in <prompt> and </prompt>.

### Instruction:
{instruction}

### Response:
<prompt>
"""


@log_calls("Evaluating dataset")
def evaluate_prompt(prompt: str, dataset: Dataset):
    """Return the accuracy of ``prompt`` as a sentiment classifier on ``dataset``.

    For every datum, ``prompt`` and the datum's ``"text"`` are rendered into
    CLASSIFICATION_PROMPT and sent to the module-level ``evaluation_model``.
    The lower-cased response is mapped to a label id by the first of
    ``"negative"`` (0) / ``"positive"`` (1) it contains; a response that
    contains neither is logged and counted as ``"failed"``.

    Args:
        prompt: the candidate instruction to evaluate.
        dataset: iterable of records with a ``"text"`` field and an integer
            ``"label"`` field that is compared against the ids above.

    Returns:
        The fraction of data classified correctly. Failed answers count as
        wrong (they are part of the denominator). An empty dataset yields 0.0
        instead of raising ZeroDivisionError.
    """
    sst2_labels = {"negative": 0, "positive": 1}

    results: DefaultDict[str, int] = defaultdict(int)
    dataset_iterator = tqdm(dataset, desc="evaluating prompt", leave=False)

    for datum in dataset_iterator:
        response = evaluation_model(
            prompt=CLASSIFICATION_PROMPT.format(instruction=prompt, input=datum["text"])
        )
        answer = response.lower()
        # First matching label (dict insertion order) wins, so an answer that
        # mentions both words is classified as "negative".
        answer_label = next(
            (label_id for label, label_id in sst2_labels.items() if label in answer),
            None,
        )
        if answer_label is None:
            logger.warning(f"Invalid answer: {answer}")
            results["failed"] += 1
        else:
            results["correct" if answer_label == datum["label"] else "incorrect"] += 1
        # Refresh the progress bar after every datum, failures included.
        dataset_iterator.set_postfix(results)

    total = sum(results.values())
    if total == 0:
        return 0.0
    return results["correct"] / total


@log_calls("Paraphrasing prompts")
def paraphrase_prompts(prompt: str, n: int):
    """Ask the module-level ``evolution_model`` for ``n`` paraphrases of ``prompt``.

    PARAPHRASE_PROMPT ends with an opening ``<prompt>`` tag and instructs the
    model to close the paraphrase with ``</prompt>``. Each response is cut at
    the first ``</prompt>`` (dropping the closing tag and anything generated
    after it), any ``<prompt>`` tag the model re-emits before that point is
    removed, and surrounding whitespace is stripped, so only the paraphrase
    itself enters the population.

    NOTE(review): assumes ``evolution_model`` returns only the generated
    continuation, not an echo of the input prompt -- TODO confirm.

    Args:
        prompt: the instruction to paraphrase.
        n: number of paraphrases to request (one model call each).

    Returns:
        A list of ``n`` paraphrased prompts.
    """
    paraphrases = []
    for _ in range(n):
        response = evolution_model(
            prompt=PARAPHRASE_PROMPT.format(instruction=prompt)
        )
        # Keep the text before the closing tag and, if the model repeated the
        # opening tag, only what follows its last occurrence.
        paraphrase = response.split("</prompt>")[0].split("<prompt>")[-1]
        paraphrases.append(paraphrase.strip())
    return paraphrases


@log_calls("Performing selection")
def selection(prompts, scores):
    """Pick two distinct parent prompts by roulette-wheel selection.

    Args:
        prompts: the current population (at least two prompts).
        scores: dev-set score of each prompt, aligned with ``prompts``;
            assumed non-negative (they are accuracies in run_episode).

    Returns:
        A two-element sequence ``(pr1, pr2)`` of distinct prompts.
    """
    # In GA, two parent solutions are normally selected based on the roulette wheel
    # selection method according to the fitness value (Lipowski & Lipowska, 2012).
    # Similar to this, we utilize the roulette wheel selection method to select
    # two parent prompts in the current population according to the scores evaluated
    # on development sets. Specifically, let si denote the performance score on the
    # development set of the i-th prompt in the population, which contains a total
    # of N prompts. The probability of selecting the i-th prompt as a parent can be expressed as
    # pi = si / Σj=1->N sj.
    total = sum(scores)
    if total == 0:
        # sum of scores is 0 ==> each score is 0, draw with equal probability
        selection_probabilities = len(scores) * [1 / len(scores)]
    else:
        selection_probabilities = [score / total for score in scores]

    selectable = [i for i, p in enumerate(selection_probabilities) if p > 0]
    if len(selectable) >= 2:
        return choice(prompts, size=2, replace=False, p=selection_probabilities)

    # numpy's weighted sampling without replacement raises ValueError when fewer
    # than `size` entries of p are non-zero, which happens when exactly one
    # prompt has a positive score. Roulette selection always picks that prompt
    # first; the second parent is drawn uniformly from the remaining prompts.
    first = selectable[0]
    second = choice([i for i in range(len(prompts)) if i != first])
    return [prompts[first], prompts[second]]


# Instruction for the GA evolution operator (rendered by evolution_ga):
# crossover of the two selected parents followed by mutation, with the final
# prompt requested between <prompt> and </prompt> tags.
GA_PROMPT = """
Please follow the instruction step-by-step to generate a better prompt.
1. Cross over the following prompts and generate a new prompt:
Prompt 1: {prompt1}
Prompt 2: {prompt2}
2. Mutate the prompt generated in Step 1 and generate a final prompt bracketed with <prompt> and </prompt>.
"""


# Instruction for the DE evolution operator (rendered by evolution_de):
# - {prompt1} / {prompt2}: the two selected donor prompts whose differing
#   parts are identified and mutated,
# - {prompt3}: the prompt those mutated parts are merged into (evolution_de
#   fills it with its `best_prompt` argument),
# - {basic_prompt}: the population member the result is crossed over with.
# The final prompt is requested between <prompt> and </prompt> tags.
DE_PROMPT = """
Please follow the instruction step-by-step to generate a better prompt.
1. Identify the different parts between the Prompt 1 and Prompt 2:
Prompt 1: {prompt1}
Prompt 2: {prompt2}
2. Randomly mutate the different parts
3. Combine the different parts with Prompt 3, selectively replace it with the different parts in Step 2 and generate a new prompt.
Prompt 3: {prompt3}
4. Cross over the prompt in the Step 3 with the following basic prompt and generate a final prompt bracketed with <prompt> and </prompt>:
Basic Prompt: {basic_prompt}
"""


@log_calls("Performing prompt evolution using GA")
def evolution_ga(prompt1: str, prompt2: str):
    """Create a child prompt from two parents with the GA operator (GA_PROMPT).

    Args:
        prompt1: first parent prompt.
        prompt2: second parent prompt.

    Returns:
        If the model response contains ``<prompt>``, the text after the first
        such tag, cut at the next ``</prompt>`` (or ``<prompt>``) tag;
        otherwise the whole response. Surrounding whitespace is stripped in
        both cases so the newlines around the tags do not leak into the prompt,
        which is later formatted into CLASSIFICATION_PROMPT and used as a
        cache / family-tree key.
    """
    # Following the evolutionary operators in GA, a new candidate prompt is generated through
    # a two-step process based on the selected two parents:
    # 1) The parent prompts undergo crossover, resulting in a new prompt that
    #   selectively combines components from both parents;
    # 2) The newly generated prompt from the first step undergoes mutation,
    #   in which random alterations are made to some of its content.
    # Based on this two-step process, we design instructions, guiding LLMs to
    # generate a new prompt based on these steps to perform Evo(·) in Algorithm 1.
    evolved_prompt = evolution_model(
        prompt=GA_PROMPT.format(prompt1=prompt1, prompt2=prompt2)
    )
    if "<prompt>" in evolved_prompt:
        evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
    return evolved_prompt.strip()


@log_calls("Performing prompt evolution using DE")
def evolution_de(prompt1: str, prompt2: str, basic_prompt: str, best_prompt: str):
    """Create a new prompt with the differential-evolution operator (DE_PROMPT).

    The evolution model is instructed to (1) identify the parts in which
    ``prompt1`` and ``prompt2`` differ, (2) randomly mutate those parts,
    (3) merge them into ``best_prompt`` ("Prompt 3") and (4) cross the result
    over with ``basic_prompt``, returning the final prompt between
    ``<prompt>`` and ``</prompt>`` tags.

    Args:
        prompt1: first donor prompt.
        prompt2: second donor prompt.
        basic_prompt: the population member being evolved ("Basic Prompt").
        best_prompt: the prompt the mutated differences are merged into
            (the current best prompt in run_episode).

    Returns:
        If the model response contains ``<prompt>``, the text after the first
        such tag, cut at the next ``</prompt>`` (or ``<prompt>``) tag;
        otherwise the whole response. Surrounding whitespace is stripped in
        both cases so the newlines around the tags do not leak into the prompt.
    """
    evolved_prompt = evolution_model(
        prompt=DE_PROMPT.format(
            prompt1=prompt1,
            prompt2=prompt2,
            prompt3=best_prompt,
            basic_prompt=basic_prompt,
        )
    )
    if "<prompt>" in evolved_prompt:
        evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
    return evolved_prompt.strip()


@log_calls("Updating prompts")
def update(prompts: list[str], scores: list[float], N: int):
    """Keep the ``N`` highest-scoring prompts out of ``prompts``.

    The first ``N`` candidates fill the kept set. Every later candidate
    replaces the first occurrence of the current minimum score, but only if
    it scores strictly higher than that minimum.

    Args:
        prompts: candidate prompts (current population followed by new ones).
        scores: score of each candidate, aligned with ``prompts``.
        N: number of prompts to keep.

    Returns:
        ``(kept_prompts, kept_scores)`` as two aligned lists.
    """
    # EVOPROMPT (GA) produces N new prompts per iteration, merges them with the
    # current N prompts, and keeps the N prompts with the highest dev-set scores.
    kept_prompts: list[str] = []
    kept_scores: list[float] = []

    for candidate, candidate_score in zip(prompts, scores):
        if len(kept_prompts) < N:
            kept_prompts.append(candidate)
            kept_scores.append(candidate_score)
            continue

        worst_score = min(kept_scores)
        if not candidate_score > worst_score:
            continue
        worst_position = kept_scores.index(worst_score)
        kept_prompts[worst_position] = candidate
        kept_scores[worst_position] = candidate_score

    return kept_prompts, kept_scores


def run_episode(evo_alg_str: str):
    """Run one EVOPROMPT optimisation episode on SST-2 and report the best prompt.

    Builds an initial population from the SST-2 base prompt and its
    paraphrases, evolves it for ``T`` iterations with the GA or DE operator
    (scores are dev-set accuracies on 200 validation examples), saves the
    snapshot/genealogy/family-tree artefacts, then evaluates the best prompt
    of the final population on the SST-2 test split.

    Args:
        evo_alg_str: ``"ga"`` (genetic algorithm) or ``"de"`` (differential
            evolution).

    Returns:
        ``(best_prompt, test_accuracy)``.

    Raises:
        ValueError: if ``evo_alg_str`` is neither ``"ga"`` nor ``"de"``.
    """
    # Fail fast instead of after paraphrasing and evaluating the population.
    if evo_alg_str not in ("ga", "de"):
        raise ValueError(f"Unknown evolution algorithm: {evo_alg_str!r}")

    # Algorithm 1 Discrete prompt optimization: EVOPROMPT

    # Require:
    # - Size of population
    N = 10
    # - Initial prompts P0 = {p1, p2, . . . , pN }
    sst2_base_prompt = """In this task, you are given sentences from movie reviews. The task is to classify a sentence as "’positive’" if the sentiment of the sentence is positive or as "’negative’" if the sentiment of the sentence is negative. Return label only without any other text."""  #  from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning
    paraphrases = paraphrase_prompts(sst2_base_prompt, n=N - 1)
    # the current population
    population = [sst2_base_prompt] + paraphrases

    # P keeps track of prompts per generation:
    # P[0] is the initial population, P[t] the prompts evolved in iteration t.
    P = [population]

    # - A dev set D
    # The size of the development set is 200.
    D = load_dataset("SetFit/sst2", split="validation[:200]")

    # - fD(·) denotes the score of a prompt on the desired LLM evaluated on D
    # (cached so a prompt that is generated again is not re-evaluated)
    f_D = lru_cache(maxsize=None)(partial(evaluate_prompt, dataset=D))

    # - a pre-defined number of iterations T
    T = 10

    # - carefully designed evolutionary operators to generate a new prompt Evo(·)

    # Line 1: Initial evaluation scores: S0 ← {si = fD (pi )|i ∈ [1, N ]}
    # the current population's scores
    population_scores = [f_D(p) for p in P[0]]
    # S keeps track of scores, aligned with P
    S = [population_scores]

    # add initial prompts to family tree; None marks that there is no parent
    for prompt in P[0]:
        family_tree[prompt] = None

    # Line 2:
    for t in trange(1, T + 1, desc="T", leave=True):
        # Line 3: Selection: select a certain number of prompts from current population as parent prompts
        # pr1,...,prk ∼ Pt−1
        if evo_alg_str == "de":
            # DE needs the best prompt *text* of the current population
            best_index = max(range(N), key=lambda i: population_scores[i])
            best_prompt_current_evolution = population[best_index]

        # start new generation
        P.append([])
        S.append([])

        for i in trange(N, desc="N", leave=False):
            # for both GA and DE we start with two parent prompts
            pr1, pr2 = selection(population, population_scores)

            # Line 4: Evolution: generate a new prompt based on the selected parent prompts by leveraging LLM to perform evolutionary operators
            # p′i ←Evo(pr1,...,prk)
            if evo_alg_str == "ga":
                p_i = evolution_ga(pr1, pr2)
            else:
                p_i = evolution_de(
                    pr1, pr2, population[i], best_prompt_current_evolution
                )

            # Line 5: Evaluation
            # s′_i ← f(p′i,D)
            s_i = f_D(p_i)

            P[t].append(p_i)
            S[t].append(s_i)

            # keep track of genealogy
            family_tree[p_i] = (pr1, pr2)

        # Line 6: Update based on the evaluation scores
        # Pt ← {Pt−1, p′i} and St ← {St−1, s′i}
        if evo_alg_str == "ga":
            # GA keeps N best prompts from current population and evolutions
            population, population_scores = update(
                population + P[t], population_scores + S[t], N
            )
        else:
            # DE keeps the evolved prompt if it beats the basic prompt at the
            # same position, and the basic prompt otherwise
            survivors = [
                (new_prompt, new_score)
                if new_score > old_score
                else (old_prompt, old_score)
                for old_prompt, old_score, new_prompt, new_score in zip(
                    population, population_scores, P[t], S[t]
                )
            ]
            population = [prompt for prompt, _ in survivors]
            population_scores = [score for _, score in survivors]

    save_snapshot(family_tree, P, S, T, N)
    save_genealogy(family_tree, P, S, T)
    save_family_tree_visualization(family_tree, P, S, T + 1)

    # Line 8: Return the best prompt, p∗, among the final population PT :
    # p∗ ← argmaxp∈PT f(p, D)
    # (PT is the updated population, not P[T], which only holds the prompts
    # evolved in the last iteration)
    best_index = max(range(len(population)), key=lambda i: population_scores[i])
    best_prompt = population[best_index]
    logger.info(f"Best prompt: {best_prompt}")

    # We pick the prompt with the highest score on the development set and report its score on the testset.
    test_D = load_dataset("SetFit/sst2", split="test")
    test_accuracy = evaluate_prompt(best_prompt, test_D)
    logger.info(f"Test accuracy of best prompt: {test_accuracy}")
    return best_prompt, test_accuracy


family_tree = {}
if __name__ == "__main__":
    options = argument_parser.parse_args()

    # set up evolution model
    match options.evolution_engine:
        case "llama2":
            logger.info("Using Llama2 client as the evolution engine")
            evolution_model = Llama2(
                str(current_directory / "models/llama-2-13b-chat.Q5_K_M.gguf"),
                chat=USE_CHAT,
            )

        case "openai":
            logger.info("Using OpenAI client as the evolution engine")
            evolution_model = OpenAI("gpt-3.5-turbo", chat=USE_CHAT)

    # set up evaluation model
    # NOTE currenty we always stick to Llama2 as evaluation model
    match options.evolution_engine:
        case "llama2":
            evaluation_model = evolution_model
        case "openai":
            evaluation_model = Llama2(
                str(current_directory / "models/llama-2-13b-chat.Q5_K_M.gguf"),
                chat=USE_CHAT,
            )

    run_episode(evo_alg_str=options.evolution_algorithm)