diff --git a/cli.py b/cli.py
index ab05d658a20c4392677707429d9265eed26cb2e0..071c220d14580eeed8cbf0263a894ddd80790ee8 100644
--- a/cli.py
+++ b/cli.py
@@ -1,14 +1,31 @@
 from argparse import ArgumentParser
 
+from models import MODELS
+
+# from typing import Literal
+
+# import pydantic
+
+# class Arguments(pydantic.BaseModel):
+#     # Required Args
+#     evolution_engine: Literal['llama2', 'openai'] = pydantic.Field("llama2", description="the evolution engine")
+#     evolution_algorithm: Literal['ga', 'de'] = pydantic.Field(description="the evolution algorithm")
+#     # model_path: str = pydantic.Field(description="the model path")
+#     task: Literal['sa', 'qa'] = pydantic.Field(description="the model path")
+#     use_grammar: bool = pydantic.Field(False, description="use grammar")
+#     debug: bool = pydantic.Field(None, description="use grammar")
+#     chat: bool = pydantic.Field(False, description="use chat mode")
+
+
 argument_parser = ArgumentParser()
 argument_parser.add_argument(
-    "--evolution-engine", "-e", type=str, choices=["openai", "llama2"], default="llama2"
+    "--evolution-engine", "-e", type=str, choices=MODELS.keys(), default="llama2"
 )
 argument_parser.add_argument(
     "--evolution-algorithm", "-a", type=str, choices=["ga", "de"], default="ga"
 )
-argument_parser.add_argument("--model-path", "-m", type=str, required=True)
+argument_parser.add_argument("--model", "-m", type=str)
 argument_parser.add_argument(
     "--task", "-t", type=str, required=True, choices=["sa", "qa"]
 )
diff --git a/evolution.py b/evolution.py
index 9a935fac9ede1c7056f65cd42836835f607100ef..200a1b5bdd261c0190898bcb026a540b7dc37a16 100644
--- a/evolution.py
+++ b/evolution.py
@@ -3,10 +3,10 @@ from abc import abstractmethod
 from models import LLMModel
 from numpy.random import choice
 from opt_types import ModelUsage, Prompt
-from optimization import PromptOptimization
+from optimization import PromptOptimization, save_snapshot
 from task import Task
 from tqdm import trange
-from utils import initialize_run_directory, log_calls, logger, save_snapshot
+from utils import initialize_run_directory, log_calls, logger
 
 SYSTEM_MESSAGE = (
     "Please follow the instruction step-by-step to generate a better prompt."
@@ -100,7 +100,9 @@ class EvolutionAlgorithm(PromptOptimization):
 
         run_directory = initialize_run_directory(self.evolution_model)
 
-        initial_prompts, evolution_usage, evaluation_usage = self.init_run(self.population_size)
+        initial_prompts, evolution_usage, evaluation_usage = self.init_run(
+            self.population_size
+        )
 
         total_evaluation_usage += evaluation_usage
         total_evolution_usage += evolution_usage
@@ -166,7 +168,12 @@ class EvolutionAlgorithm(PromptOptimization):
         # We pick the prompt with the highest score on the development set and report its score on the testset.
         test_performance, _ = self.task.evaluate_test(p.content)
         logger.info("Best prompt on test set: %s", test_performance)
-        logger.info("Usage (evolution model / evaluation model / total): %s / %s / %s", total_evolution_usage, total_evaluation_usage, total_evolution_usage + total_evaluation_usage)
+        logger.info(
+            "Usage (evolution model / evaluation model / total): %s / %s / %s",
+            total_evolution_usage,
+            total_evaluation_usage,
+            total_evolution_usage + total_evaluation_usage,
+        )
 
         return total_evolution_usage, total_evaluation_usage
diff --git a/main.py b/main.py
index 2095276a34e462f8f4489e69a9d0983247964081..4a139592384d88bce4fb0da3786aff2718a9143d 100644
--- a/main.py
+++ b/main.py
@@ -4,7 +4,7 @@ from typing import Any
 from cli import argument_parser
 from dotenv import load_dotenv
 from evolution import DifferentialEvolution, GeneticAlgorithm
-from models import Llama2, OpenAI
+from models import Llama2, get_model_init
 from task import QuestionAnswering, SentimentAnalysis
 from utils import logger
 
@@ -25,26 +25,29 @@ if __name__ == "__main__":
     options = argument_parser.parse_args()
 
     # set up evolution model
+    model_init_fn = get_model_init(options.evolution_engine)
+    evolution_model = model_init_fn(
+        options.model,
+        chat=options.chat,
+    )
+
     match options.evolution_engine:
         case "llama2":
             logger.info("Using Llama2 client as the evolution engine")
-            evolution_model = Llama2(
-                model_path=options.model_path,
-                chat=options.chat,
-            )
-
         case "openai":
             logger.info("Using OpenAI client as the evolution engine")
-            evolution_model = OpenAI("gpt-3.5-turbo", chat=options.chat)
+
 
     # set up evaluation model
-    # NOTE currenty we always stick to Llama2 as evaluation model
+    # NOTE currently we always stick to Llama2 as evaluation engine
+    # TODO allow to set separate engine and model for evaluation?
+    logger.info("Using Llama2 client as the evaluation engine")
     match options.evolution_engine:
         case "llama2":
             evaluation_model = evolution_model
         case "openai":
             evaluation_model = Llama2(
-                model_path=options.model_path,
+                model_path=options.model,
                 chat=options.chat,
             )
diff --git a/models.py b/models.py
index 486801546244e0c28ce8de309a91a8f6ff866125..f6f9fd6afa94004914f98e4db8f003df25f044a6 100644
--- a/models.py
+++ b/models.py
@@ -1,6 +1,6 @@
 from abc import abstractmethod
 from pathlib import Path
-from typing import Any
+from typing import Any, Callable
 
 import openai
 from llama_cpp import Llama
@@ -9,6 +9,31 @@ from opt_types import ModelUsage
 
 current_directory = Path(__file__).resolve().parent
 
+MODELS = {}
+
+
+def register_model(name: str):
+    def wrapper(f: Callable):
+        global MODELS
+        if name in MODELS:
+            raise ValueError(f"Cannot register model class {name}: already exists")
+        MODELS[name] = f
+
+        return f
+
+    return wrapper
+
+
+def get_models():
+    return MODELS
+
+
+def get_model_init(name: str):
+    if name not in MODELS:
+        raise ValueError(f"Model {name} does not exist")
+    return MODELS[name]
+
+
 class LLMModel:
     chat: bool
     model: Any
@@ -35,6 +60,7 @@ class LLMModel:
         pass
 
 
+@register_model("llama2")
 class Llama2(LLMModel):
     """Loads and queries a Llama2 model."""
 
@@ -117,11 +143,16 @@ class Llama2(LLMModel):
         return response_text, usage
 
 
+@register_model("openai")
 class OpenAI(LLMModel):
     """Queries an OpenAI model using its API."""
 
     def __init__(
-        self, model: str, chat: bool = False, verbose: bool = False, **kwargs
+        self,
+        model: str = "gpt-3.5-turbo",
+        chat: bool = False,
+        verbose: bool = False,
+        **kwargs,
     ) -> None:
         super().__init__(chat, model)
 
diff --git a/optimization.py b/optimization.py
index d379907ecca62cdb5ccc01b1af48b5cf2d90f07a..9f18365c07f93fedd8753f5a082bb6dff17e11d0 100644
--- a/optimization.py
+++ b/optimization.py
@@ -1,7 +1,10 @@
+import json
 from itertools import zip_longest
+from pathlib import Path
+from typing import Any
 
-from models import LLMModel
-from opt_types import ModelUsage, Prompt
+from models import Llama2, LLMModel, OpenAI
+from opt_types import ModelUsage, OptTypeEncoder, Prompt
 from task import Task
 from utils import log_calls
 
@@ -89,7 +92,9 @@ class PromptOptimization:
     def get_prompts(self, prompt_ids: list[str]):
         return [self.get_prompt(p_id) for p_id in prompt_ids]
 
-    def init_run(self, num_initial_prompts: int) -> tuple[list[Prompt], ModelUsage, ModelUsage]:
+    def init_run(
+        self, num_initial_prompts: int
+    ) -> tuple[list[Prompt], ModelUsage, ModelUsage]:
         # - Initial prompts P0 = {p1, p2, . . . , pN }
         paraphrases, paraphrase_usage = paraphrase_prompts(
             self.evolution_model, self.task.base_prompt, n=num_initial_prompts - 1
@@ -107,3 +112,58 @@ class PromptOptimization:
             evaluation_usage += prompt.usage
 
         return initial_prompts, paraphrase_usage, evaluation_usage
+
+
+# TODO turn snapshots methods into instance methods of optimizer
+def save_snapshot(
+    run_directory: Path,
+    all_prompts: list[Prompt],
+    family_tree: dict[str, tuple[str, str] | None],
+    P: list[list[str]],
+    T: int,
+    N: int,
+    task,
+    model: Llama2 | OpenAI,
+    evaluation_usage: ModelUsage,
+    evolution_usage: ModelUsage,
+    run_options: dict[str, Any],
+):
+
+    with open(run_directory / "snapshot.json", "w") as f:
+        json.dump(
+            {
+                "all_prompts": all_prompts,
+                "family_tree": family_tree,
+                "P": P,
+                "T": T,
+                "N": N,
+                "task": {
+                    "name": task.__class__.__name__,
+                    "validation_dataset": task.validation_dataset.info.dataset_name,
+                    "test_dataset": task.test_dataset.info.dataset_name,
+                    "metric": task.metric_name,
+                    "use_grammar": task.use_grammar,
+                },
+                "model": {"name": model.__class__.__name__},
+                "evaluation_usage": evaluation_usage,
+                "evolution_usage": evolution_usage,
+                "run_options": run_options,
+            },
+            f,
+            indent=4,
+            cls=OptTypeEncoder,
+        )
+
+
+def load_snapshot(path: Path):
+    import json
+
+    with path.open("r") as f:
+        snapshot = json.load(f)
+    return (
+        snapshot["family_tree"],
+        snapshot["P"],
+        snapshot["S"],
+        snapshot["T"],
+        snapshot["N"],
+    )
diff --git a/task.py b/task.py
index 3c2dc0eecc2d065a6430c1290f1033a925c7206b..461348d35be3b41dbb79986ca5df87e47640056b 100644
--- a/task.py
+++ b/task.py
@@ -8,8 +8,9 @@ from datasets import Dataset, load_dataset
 from evaluate import load as load_metric
 from llama_cpp import LlamaGrammar
 from models import Llama2, OpenAI
+from opt_types import ModelUsage
 from tqdm import tqdm
-from utils import ModelUsage, log_calls, logger
+from utils import log_calls, logger
 
 SYSTEM_MESSAGE = """
 You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
diff --git a/utils.py b/utils.py
index 3f4781de4f347fda537b1ceae42b5f7287405a59..c6c014bc5eaf766535fc452c89fefab3db2785c4 100644
--- a/utils.py
+++ b/utils.py
@@ -1,5 +1,4 @@
 import inspect
-import json
 import logging
 import os
 import re
@@ -10,9 +9,6 @@ from textwrap import dedent, indent
 from typing import Any, Callable
 from uuid import uuid4
 
-from models import Llama2, OpenAI
-from opt_types import ModelUsage, OptTypeEncoder, Prompt
-
 current_directory = Path(__file__).resolve().parent
 logger = logging.getLogger("test-classifier")
 logger.setLevel(level=logging.DEBUG)
@@ -36,7 +32,7 @@ Only return the name without any text before or after.""".strip()
 RUNS_DIR = current_directory / "runs"
 
 
-def initialize_run_directory(model: OpenAI | Llama2):
+def initialize_run_directory(model: Callable):
     response, usage = model(None, run_name_prompt, chat=True)
     model.usage -= usage
     run_name_match = re.search(r"^\w+$", response, re.MULTILINE)
@@ -126,57 +122,3 @@ class log_calls:
             arguments[argument_name] = value
 
         return arguments
-
-
-def save_snapshot(
-    run_directory: Path,
-    all_prompts: list[Prompt],
-    family_tree: dict[str, tuple[str, str] | None],
-    P: list[list[str]],
-    T: int,
-    N: int,
-    task,
-    model: Llama2 | OpenAI,
-    evaluation_usage: ModelUsage,
-    evolution_usage: ModelUsage,
-    run_options: dict[str, Any],
-):
-
-    with open(run_directory / "snapshot.json", "w") as f:
-        json.dump(
-            {
-                "all_prompts": all_prompts,
-                "family_tree": family_tree,
-                "P": P,
-                "T": T,
-                "N": N,
-                "task": {
-                    "name": task.__class__.__name__,
-                    "validation_dataset": task.validation_dataset.info.dataset_name,
-                    "test_dataset": task.test_dataset.info.dataset_name,
-                    "metric": task.metric_name,
-                    "use_grammar": task.use_grammar,
-                },
-                "model": {"name": model.__class__.__name__},
-                "evaluation_usage": evaluation_usage,
-                "evolution_usage": evolution_usage,
-                "run_options": run_options,
-            },
-            f,
-            indent=4,
-            cls=OptTypeEncoder,
-        )
-
-
-def load_snapshot(path: Path):
-    import json
-
-    with path.open("r") as f:
-        snapshot = json.load(f)
-    return (
-        snapshot["family_tree"],
-        snapshot["P"],
-        snapshot["S"],
-        snapshot["T"],
-        snapshot["N"],
-    )
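
Note on the new model registry (illustrative, not part of the patch): models.py now fills a module-level MODELS dict through the @register_model(...) decorator, cli.py derives the --evolution-engine choices from MODELS.keys(), and main.py resolves the chosen name back to a constructor with get_model_init(...) before instantiating it. The standalone Python sketch below mirrors that flow under those assumptions; EchoModel and the "echo" name are hypothetical stand-ins for the real Llama2/OpenAI classes, not code from the repository.

from typing import Callable

MODELS: dict[str, Callable] = {}


def register_model(name: str):
    # Decorator that records a model class under `name`, mirroring models.py.
    def wrapper(f: Callable):
        if name in MODELS:
            raise ValueError(f"Cannot register model class {name}: already exists")
        MODELS[name] = f
        return f

    return wrapper


def get_model_init(name: str) -> Callable:
    # Resolve a registered name to its constructor, as main.py now does.
    if name not in MODELS:
        raise ValueError(f"Model {name} does not exist")
    return MODELS[name]


@register_model("echo")
class EchoModel:
    # Hypothetical stand-in for Llama2/OpenAI; only demonstrates the registration flow.
    def __init__(self, model: str = "echo", chat: bool = False) -> None:
        self.model = model
        self.chat = chat


if __name__ == "__main__":
    print(sorted(MODELS))  # registered names feed argparse's `choices` in cli.py
    model_init_fn = get_model_init("echo")
    evolution_model = model_init_fn("echo", chat=True)
    print(type(evolution_model).__name__, evolution_model.chat)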