diff --git a/cli.py b/cli.py
index 29c37807245ca06259964f1032b6408c25f02dde..0bed0f3c1ce8f8b8fdc5dcd50f97af03706639f9 100644
--- a/cli.py
+++ b/cli.py
@@ -12,3 +12,4 @@ argument_parser.add_argument("--model-path", "-m", type=str, required=True)
 argument_parser.add_argument(
     "--task", "-t", type=str, required=True, choices=["sa", "qa"]
 )
+argument_parser.add_argument("--debug", "-d", action='store_true', default=None)
diff --git a/main.py b/main.py
index bf58276fdbbaebe57b37c2b25975b117e1b52d35..66dd0a23510adcf5d6ec4c4e69b3551b0bb486e2 100644
--- a/main.py
+++ b/main.py
@@ -1,10 +1,8 @@
 from functools import lru_cache
-from functools import lru_cache, partial
+import os
 from pathlib import Path
-from typing import DefaultDict, get_type_hints
+from typing import Any
 
-from datasets import Dataset, load_dataset
-from evaluate import load as load_metric
 from dotenv import load_dotenv
 from numpy.random import choice
 from tqdm import trange
@@ -20,6 +18,18 @@
 from utils import (
     save_genealogy,
     save_snapshot,
 )
+
+
+def conv2bool(_str: Any):
+    if isinstance(_str, bool):
+        return _str
+    if str(_str).lower() in ["1", "true"]:
+        return True
+    if str(_str).lower() in ["0", "false"]:
+        return False
+    return None
+
+
 # whether to use chat model for LLM or not
 USE_CHAT: bool = False
@@ -71,7 +81,7 @@ GA_PROMPT = """
 Please follow the instruction step-by-step to generate a better prompt.
 1. Cross over the following prompts and generate a new prompt:
 Prompt 1: {prompt1}
-Prompt 2:{prompt2}
+Prompt 2: {prompt2}
 2. Mutate the prompt generated in Step 1 and generate a final prompt bracketed with <prompt> and </prompt>.
 """
 
@@ -146,12 +156,12 @@ def update(prompts: list[str], scores: list[float], N: int):
     return retained_prompts, retained_scores
 
 
-def run_episode(evo_alg_str: str):
+def run_episode(evo_alg_str: str, debug: bool = False):
     # Algorithm 1 Discrete prompt optimization: EVOPROMPT
 
     # Require:
     # - Size of population
-    N = 10
+    N = 3 if debug else 10
     # - Initial prompts P0 = {p1, p2, . . . , pN }
     paraphrases = paraphrase_prompts(task.base_prompt, n=N - 1)
     # the initial population
@@ -161,7 +171,7 @@
     f_D = lru_cache(maxsize=None)(task.evaluate_validation)
 
     # - a pre-defined number of iterations T
-    T = 10
+    T = 2 if debug else 10
 
     # - carefully designed evolutionary operators to generate a new prompt Evo(·)
 
@@ -240,9 +250,11 @@
         population, population_scores = list(
             zip(
                 *[
-                    (new_prompt, new_prompt_score)
-                    if new_prompt_score > current_prompt_score
-                    else (current_prompt, current_prompt_score)
+                    (
+                        (new_prompt, new_prompt_score)
+                        if new_prompt_score > current_prompt_score
+                        else (current_prompt, current_prompt_score)
+                    )
                     for current_prompt, current_prompt_score, new_prompt, new_prompt_score in zip(
                         P[t - 1], S[t - 1], new_evolutions, new_evolutions_scores
                     )
@@ -269,6 +281,14 @@
 
 if __name__ == "__main__":
     options = argument_parser.parse_args()
+    # debug mode will allow for a quick run; CLI flag wins, else fall back to
+    # the EP_DEBUG env var (bugfix: bind `debug` unconditionally — the original
+    # patch left it unbound when --debug was passed, crashing with NameError)
+    debug = (options.debug if options.debug is not None
+             else conv2bool(os.getenv("EP_DEBUG", False)))
+    if debug is None:
+        raise ValueError(
+            f"{os.getenv('EP_DEBUG')} is not allowed for env variable EP_DEBUG."
+        )
+
     # set up evolution model
     match options.evolution_engine:
         case "llama2":
@@ -299,20 +319,20 @@
             evaluation_model,
             "SetFit/sst2",
             "SetFit/sst2",
-            validation_split="validation",
-            test_split="test",
+            validation_split=f"validation[:{5 if debug else 200}]",
+            test_split="test[:20]" if debug else "test",
         )
     case "qa":
         task = QuestionAnswering(
             evaluation_model,
             "squad",
             "squad",
-            validation_split=f"train[:{5 if DEBUG else 200}]",
-            test_split="validation[:20]" if DEBUG else "validation",
+            validation_split=f"train[:{5 if debug else 200}]",
+            test_split="validation[:20]" if debug else "validation",
         )
     case _:
         raise ValueError(
             f"Task {options.task} does not exist. Choose from 'sa', 'qa'."
         )
-    run_episode(evo_alg_str=options.evolution_algorithm)
+    run_episode(evo_alg_str=options.evolution_algorithm, debug=debug)
 
diff --git a/models.py b/models.py
index fc9949d82ba7cccd3c487b2223af22a1fefa006c..a640c53e7d49ccd576451eec76ce2595450c604c 100644
--- a/models.py
+++ b/models.py
@@ -71,7 +71,7 @@ class Llama2:
 
 
 class OpenAI:
-    """Loads and queries an OpenAI model."""
+    """Queries an OpenAI model using its API."""
 
     def __init__(
         self, model: str, chat: bool = False, verbose: bool = False, **kwargs