Commit 8527a161 authored by Grießhaber Daniel

added openai as evolution engine

parent 589cc378
.gitignore
runs/**
models/**
**/__pycache__/**
.env
cli.py 0 → 100644
from argparse import ArgumentParser

argument_parser = ArgumentParser()
argument_parser.add_argument(
    "--evolution-engine", "-e", type=str, choices=["openai", "llama2"], default="llama2"
)
argument_parser.add_argument(
    "--evolution-algorithm", "-a", type=str, choices=["ga", "de"], default="ga"
)
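As a quick check, the parser can be exercised directly with an explicit argv list (illustrative only, not part of the commit):

# illustrative: argparse maps --evolution-engine to options.evolution_engine
from cli import argument_parser
options = argument_parser.parse_args(["--evolution-engine", "openai", "-a", "ga"])
assert options.evolution_engine == "openai"
assert options.evolution_algorithm == "ga"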
from functools import lru_cache, partial
from pathlib import Path
-from typing import DefaultDict
+from typing import DefaultDict, get_type_hints
from datasets import Dataset, load_dataset
-from llama_cpp import Llama
+from dotenv import load_dotenv
+from llama_cpp import Callable, Llama
from numpy.random import choice
+from openai import OpenAI
from tqdm import tqdm, trange
+from cli import argument_parser
from utils import (
    log_calls,
    logger,
@@ -15,16 +18,20 @@ from utils import (
    save_snapshot,
)
+load_dotenv()
current_directory = Path(__file__).resolve().parent
-llm = Llama(
+llama = Llama(
    str(current_directory / "models/llama-2-13b-chat.Q5_K_M.gguf"),
    chat_format="llama-2",
    verbose=False,
    n_gpu_layers=60,
    n_threads=8,
-    n_ctx=2048,
+    n_ctx=4096,
)
+llm: get_type_hints(llama.create_chat_completion)
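A note on the added line above: it is a bare annotation, so nothing is assigned here; `llm` is only bound in the `__main__` block at the bottom of this file, and the annotation documents that it should accept `create_chat_completion`-style arguments. A more conventional spelling of that intent might be (a sketch, not the commit's code):

# sketch: declare llm as a callable taking chat-completion kwargs and
# returning the generated text as a plain string
from typing import Callable
llm: Callable[..., str]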
CLASSIFICATION_PROMPT = """
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
@@ -38,12 +45,13 @@
"""
PARAPHRASE_PROMPT = """
-Below is an instruction that describes a task. Write a response that paraphrases the instruction. Only output the paraphrased instruction.
+Below is an instruction that describes a task. Write a response that paraphrases the instruction. Only output the paraphrased instruction bracketed in <prompt> and </prompt>.
### Instruction:
{instruction}
### Response:
+<prompt>
"""
@@ -54,10 +62,17 @@ def evaluate_prompt(prompt: str, dataset: Dataset):
    dataset_iterator = tqdm(dataset, desc="evaluating prompt", leave=False)
    for datum in dataset_iterator:
-        response = llm.create_completion(
-            CLASSIFICATION_PROMPT.format(instruction=prompt, input=datum["text"]),
+        response = llama.create_chat_completion(
+            messages=[
+                {
+                    "role": "user",
+                    "content": CLASSIFICATION_PROMPT.format(
+                        instruction=prompt, input=datum["text"]
+                    ),
+                }
+            ],
        )
-        answer = response["choices"][0]["text"].lower()
+        answer = response["choices"][0]["message"]["content"].lower()
        answer_label = None
        for label in sst2_labels.keys():
            if label in answer:
@@ -82,11 +97,16 @@ def evaluate_prompt(prompt: str, dataset: Dataset):
def paraphrase_prompts(prompt: str, n: int):
    paraphrases = []
    for _ in range(n):
-        response = llm.create_completion(
-            PARAPHRASE_PROMPT.format(instruction=prompt),
+        paraphrase = llm(
+            messages=[
+                {
+                    "role": "user",
+                    "content": PARAPHRASE_PROMPT.format(instruction=prompt),
+                }
+            ],
+            stop="</prompt>",
            max_tokens=200,
        )
-        paraphrase = response["choices"][0]["text"]
        paraphrases.append(paraphrase)
    return paraphrases
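A typical use, consistent with how EvoPrompt-style runs seed their initial population (the values here are hypothetical):

# hypothetical: seed the population with paraphrases of a hand-written prompt
base_prompt = "Classify the sentiment of the text as positive or negative."
initial_population = [base_prompt] + paraphrase_prompts(base_prompt, n=9)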
@@ -111,7 +131,19 @@ Please follow the instruction step-by-step to generate a better prompt.
Prompt 1: {prompt1}
Prompt 2: {prompt2}
2. Mutate the prompt generated in Step 1 and generate a final prompt bracketed with <prompt> and </prompt>.
+<prompt>
"""
+DE_PROMPT = """
+Please follow the instruction step-by-step to generate a better prompt.
+1. Identify the different parts between the Prompt 1 and Prompt 2:
+Prompt 1: {prompt1}
+Prompt 2: {prompt2}
+2. Randomly mutate the different parts
+3. Combine the different parts with Prompt 3, selectively replace it with the different parts in Step 2 and generate a new prompt.
+Prompt 3: {prompt3}
+4. Cross over the prompt in the Step 3 with the following basic prompt and generate a final prompt bracketed with <prompt> and </prompt>:
+Basic Prompt: {basic_prompt}
+"""
@@ -125,13 +157,18 @@ def evolution(prompt1: str, prompt2: str):
    # in which random alterations are made to some of its content.
    # Based on this two-step process, we design instructions, guiding LLMs to
    # generate a new prompt based on these steps to perform Evo(·) in Algorithm 1.
-    response = llm.create_completion(
-        GA_PROMPT.format(prompt1=prompt1, prompt2=prompt2),
+    evolved_prompt = llm(
+        messages=[
+            {
+                "role": "user",
+                "content": GA_PROMPT.format(prompt1=prompt1, prompt2=prompt2),
+            }
+        ],
+        max_tokens=None,
        stop="</prompt>",
    )
-    evolved_prompt_response = response["choices"][0]["text"]
-    evolved_prompt = evolved_prompt_response.split("</prompt>")[0]
+    if "<prompt>" in evolved_prompt:
+        evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
    return evolved_prompt
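In Algorithm 1's loop this recombines two parent prompts into a child, e.g. (hypothetical call):

# hypothetical: recombine two parent prompts into a child prompt
child = evolution(
    "Classify the sentiment of the text.",
    "Decide if the sentiment is positive or negative.",
)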
@@ -158,8 +195,7 @@ def update(prompts: list[str], scores: list[float], N: int):
    return retained_prompts, retained_scores
-family_tree = {}
-if __name__ == "__main__":
+def run_episode():
    # Algorithm 1 Discrete prompt optimization: EVOPROMPT
    # Require:
@@ -221,3 +257,32 @@ if __name__ == "__main__":
    # We pick the prompt with the highest score on the development set and report its score on the test set.
    test_D = load_dataset("SetFit/sst2", split="test")
    evaluate_prompt(P[p], test_D)
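How the index `p` is chosen is not shown in this hunk; presumably it is the argmax over the development-set scores, along the lines of (hypothetical; `P` and `S` stand for the population and its scores in Algorithm 1):

# hypothetical: pick the index of the best-scoring prompt on the dev set
p = max(range(len(S)), key=lambda i: S[i])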
+family_tree = {}
+if __name__ == "__main__":
+    options = argument_parser.parse_args()
+    if options.evolution_algorithm == "de":
+        raise NotImplementedError("DE is not implemented yet")
+    match options.evolution_engine:
+        case "llama2":
+            logger.info("Using Llama2 client as the evolution engine")
+            llm = lambda *args, **kwargs: llama.create_chat_completion(
+                *args,
+                **kwargs,
+            )["choices"][0]["message"]["content"]
+        case "openai":
+            logger.info("Using OpenAI client as the evolution engine")
+            openai_client = OpenAI()
+            llm = (
+                lambda *args, **kwargs: openai_client.chat.completions.create(
+                    model="gpt-3.5-turbo",
+                    *args,
+                    **kwargs,
+                )
+                .choices[0]
+                .message.content
+            )
+    run_episode()
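With either engine, `llm` ends up with the same shape: chat-completion keyword arguments in, plain response text out, which is exactly how evolution() and paraphrase_prompts() call it. A hypothetical call:

# hypothetical: both engine adapters are invoked identically
text = llm(
    messages=[{"role": "user", "content": "Say hello."}],
    stop="</prompt>",
)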
requirements.txt
@@ -3,3 +3,5 @@ datasets
llama-cpp-python
tqdm
graphviz
+python-dotenv
+openai