From d47234e3548c2e2df8d6661cf3054c0bdd481529 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Thu, 25 Apr 2024 10:52:14 +0200
Subject: [PATCH] move cli parameters to local modules

---
 cli.py       | 34 -------------------------
 evolution.py | 30 ++++++++++++++++++++---
 main.py      | 46 ++++++++--------------------
 models.py    | 65 +++++++++++++++++++++++++++---------------------
 task.py      | 70 +++++++++++++++++++++++++++-------------------------
 5 files changed, 111 insertions(+), 134 deletions(-)

diff --git a/cli.py b/cli.py
index 89a875c..cc167df 100644
--- a/cli.py
+++ b/cli.py
@@ -1,37 +1,3 @@
 from argparse import ArgumentParser
-from models import MODELS
-
-# from typing import Literal
-
-# import pydantic
-
-# class Arguments(pydantic.BaseModel):
-#     # Required Args
-#     evolution_engine: Literal['llama2', 'openai'] = pydantic.Field("llama2", description="the evolution engine")
-#     evolution_algorithm: Literal['ga', 'de'] = pydantic.Field(description="the evolution algorithm")
-#     # model_path: str = pydantic.Field(description="the model path")
-#     task: Literal['sa', 'qa'] = pydantic.Field(description="the model path")
-#     use_grammar: bool = pydantic.Field(False, description="use grammar")
-#     debug: bool = pydantic.Field(None, description="use grammar")
-#     chat: bool = pydantic.Field(False, description="use chat mode")
-
 argument_parser = ArgumentParser()
-
-argument_parser.add_argument(
-    "--evolution-engine", "-e", type=str, choices=MODELS.keys(), default="llama2"
-)
-argument_parser.add_argument(
-    "--evolution-algorithm", "-a", type=str, choices=["ga", "de"], default="ga"
-)
-argument_parser.add_argument(
-    "--task", "-t", type=str, required=True, choices=["sa", "qa"]
-)
-argument_parser.add_argument("--use-grammar", "-g", action="store_true")
-argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
-argument_parser.add_argument("--chat", "-c", action="store_true")
-argument_parser.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
-argument_parser.add_argument(
-    "--llama-path", default="models/llama-2-13b-chat.Q5_K_M.gguf"
-)
diff --git a/evolution.py b/evolution.py
index 200a1b5..4bf51eb 100644
--- a/evolution.py
+++ b/evolution.py
@@ -1,11 +1,13 @@
 from abc import abstractmethod
-from models import LLMModel
 from numpy.random import choice
+from tqdm import trange
+
+from cli import argument_parser
+from models import LLMModel
 from opt_types import ModelUsage, Prompt
 from optimization import PromptOptimization, save_snapshot
 from task import Task
-from tqdm import trange
 from utils import initialize_run_directory, log_calls, logger
 
 SYSTEM_MESSAGE = (
@@ -33,7 +35,8 @@ Basic Prompt: {basic_prompt}
 
 
 class EvolutionAlgorithm(PromptOptimization):
-    # TODO add docstrings
     """The super class for all evolution algorithms containing shared parameters."""
 
+    shorthand: str
+
     def __init__(
@@ -181,6 +184,8 @@ class EvolutionAlgorithm(PromptOptimization):
 class GeneticAlgorithm(EvolutionAlgorithm):
     """The genetic algorithm implemented using LLMs."""
 
+    shorthand = "ga"
+
     @log_calls("Performing prompt evolution using GA")
     def evolve(
         self,
@@ -233,6 +238,8 @@ class GeneticAlgorithm(EvolutionAlgorithm):
 class DifferentialEvolution(EvolutionAlgorithm):
     """The differential evolution algorithm implemented using LLMs."""
 
+    shorthand = "de"
+
     @log_calls("Performing prompt evolution using DE")
     def evolve(
         self,
@@ -275,3 +282,20 @@ class DifferentialEvolution(EvolutionAlgorithm):
         )
     ]
     return population
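+
+
+# Registry of all evolution algorithms, keyed by their CLI shorthand; any
+# EvolutionAlgorithm subclass defined above is picked up automatically.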
+optimizers = {
+    algorithm.shorthand: algorithm for algorithm in EvolutionAlgorithm.__subclasses__()
+}
+
+
+def get_optimizer_class(name: str):
+    if name not in optimizers:
+        raise ValueError(f"Optimization algorithm '{name}' does not exist")
+    return optimizers[name]
+
+
+argument_parser.add_argument(
+    "--evolution-algorithm", "-a", type=str, choices=optimizers.keys(), default="ga"
+)
diff --git a/main.py b/main.py
index 674d80c..eca0a65 100644
--- a/main.py
+++ b/main.py
@@ -4,9 +4,9 @@ from typing import Any
 from dotenv import load_dotenv
 
 from cli import argument_parser
-from evolution import DifferentialEvolution, GeneticAlgorithm
-from models import Llama2, get_model_init
-from task import QuestionAnswering, SentimentAnalysis
+from evolution import DifferentialEvolution, GeneticAlgorithm, get_optimizer_class
+from models import Llama2, LLMModel, OpenAI, get_model
+from task import QuestionAnswering, SentimentAnalysis, get_task
 from utils import logger
 
 load_dotenv()
@@ -26,8 +26,7 @@ if __name__ == "__main__":
     options = argument_parser.parse_args()
 
     # set up evolution model
-    model_init_fn = get_model_init(options.evolution_engine)
-    evolution_model = model_init_fn(options)
+    evolution_model = get_model(options.evolution_engine, options)
 
     match options.evolution_engine:
         case "llama2":
@@ -39,6 +38,7 @@ if __name__ == "__main__":
     # NOTE currently we always stick to Llama2 as evaluation engine
     # TODO allow to set separate engine and model for evaluation?
     logger.info("Using Llama2 as the evaluation engine")
+    evaluation_model: LLMModel
     match options.evolution_engine:
         case "llama2":
             evaluation_model = evolution_model
@@ -61,40 +61,14 @@ if __name__ == "__main__":
             f"'{os.getenv('EP_DEBUG')}' is not allowed for env variable EP_DEBUG."
         )
 
-    match options.task:
-        case "sa":
-            logger.info("Running with task sentiment analysis on dataset SetFit/sst2")
-            task = SentimentAnalysis(
-                evaluation_model,
-                "SetFit/sst2",
-                "SetFit/sst2",
-                use_grammar=options.use_grammar,
-                validation_split=f"validation[:{5 if debug else 200}]",
-                test_split="test[:20]" if debug else "test",
-            )
-        case "qa":
-            logger.info("Running with task question answering on dataset squad")
-            task = QuestionAnswering(
-                evaluation_model,
-                "squad",
-                "squad",
-                use_grammar=options.use_grammar,
-                validation_split=f"train[:{5 if debug else 200}]",
-                test_split="validation[:20]" if debug else "validation",
-            )
-        case _:
-            raise ValueError(
-                f"Task {options.task} does not exist. Choose from 'sa', 'qa'."
-            )
+    task = get_task(options.task, evaluation_model, options)
+    logger.info(
+        f"Running with task {task.__class__.__name__} on dataset {task.validation_dataset.info.dataset_name}"
+    )
 
     logger.info("Using evolutionary algorithm '%s'", options.evolution_algorithm)
 
-    # TODO allow to register algorithms and map to classes
-    if options.evolution_algorithm == "ga":
-        optimizer_class = GeneticAlgorithm
-    else:
-        optimizer_class = DifferentialEvolution
-
+    optimizer_class = get_optimizer_class(options.evolution_algorithm)
     optimizer = optimizer_class(
         population_size=10,
         task=task,
diff --git a/models.py b/models.py
index 56d5016..2c4cad5 100644
--- a/models.py
+++ b/models.py
@@ -1,11 +1,12 @@
 from abc import abstractmethod
-from argparse import Namespace
+from argparse import ArgumentParser, Namespace
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any
 
 import openai
 from llama_cpp import Llama
 
+from cli import argument_parser
 from opt_types import ModelUsage
 
 current_directory = Path(__file__).resolve().parent
@@ -34,33 +35,12 @@ class LLMModel:
     ) -> Any:
         pass
 
-
-MODELS: dict[str, type[LLMModel]] = {}
-
-
-def register_model(name: str):
-    def wrapper(f: type[LLMModel]):
-        global MODELS
-        if name in MODELS:
-            raise ValueError("Cannot register model class %s: already exists", name)
-        MODELS[name] = f
-
-        return f
-
-    return wrapper
-
-
-def get_models():
-    return MODELS
-
-
-def get_model_init(name: str):
-    if name not in MODELS:
-        raise ValueError("Model %s does not exist", name)
-    return MODELS[name]
+    @classmethod
+    @abstractmethod
+    def register_arguments(cls, parser: ArgumentParser):
+        pass
 
 
-@register_model("llama2")
 class Llama2(LLMModel):
     """Loads and queries a Llama2 model."""
 
@@ -142,8 +122,14 @@ class Llama2(LLMModel):
         self.usage += usage
         return response_text, usage
 
+    @classmethod
+    def register_arguments(cls, parser: ArgumentParser):
+        group = parser.add_argument_group("Llama2 model arguments")
+        group.add_argument(
+            "--llama-path", default="models/llama-2-13b-chat.Q5_K_M.gguf"
+        )
+
 
-@register_model("openai")
 class OpenAI(LLMModel):
     """Queries an OpenAI model using its API."""
 
@@ -215,3 +201,26 @@ class OpenAI(LLMModel):
         usage = ModelUsage(**response.usage.__dict__)
         self.usage += usage
         return response.choices[0].text, usage
+
+    @classmethod
+    def register_arguments(cls, parser: ArgumentParser):
+        group = parser.add_argument_group("OpenAI model arguments")
+        group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
+
+
+models = {model.__name__.lower(): model for model in LLMModel.__subclasses__()}
+for model in models.values():
+    model.register_arguments(argument_parser)
+
+
+def get_model(name: str, options: Namespace):
+    if name not in models:
+        raise ValueError(f"Model '{name}' does not exist")
+    return models[name](options)
+
+
+argument_group = argument_parser.add_argument_group("Model arguments")
+argument_group.add_argument(
+    "--evolution-engine", "-e", type=str, choices=models.keys(), default="llama2"
+)
+argument_group.add_argument("--chat", "-c", action="store_true")
diff --git a/task.py b/task.py
index abbc554..adfaa1f 100644
--- a/task.py
+++ b/task.py
@@ -1,5 +1,6 @@
 import re
 from abc import abstractmethod
+from argparse import Namespace
 from functools import lru_cache
 from typing import Union
 
@@ -8,7 +9,8 @@ from evaluate import load as load_metric
 from llama_cpp import LlamaGrammar
 from tqdm import tqdm
 
-from models import Llama2, OpenAI
+from cli import argument_parser
+from models import Llama2, LLMModel, OpenAI
 from opt_types import ModelUsage
 from utils import log_calls, logger
 
@@ -20,6 +22,7 @@ DatasetDatum = dict
 
 
 class Task:
+    shorthand: str
     validation_dataset: Dataset
     test_dataset: Dataset
 
@@ -67,9 +70,9 @@ class Task:
         for datum in dataset_iterator:
             result, usage = self._evaluate_sample(prompt, datum)
             results.append(result)
-            current_metrics = self._aggregate_result(results)
+            current_metric = self._aggregate_result(results)
             dataset_iterator.set_postfix(
-                {self.metric_name: f"{current_metrics*100:.1f}%"}
+                {self.metric_name: f"{current_metric*100:.1f}%"}
             )
             evaluation_usage += usage
 
@@ -109,24 +112,16 @@ def sa_grammar_fn(verbose: bool = False):
 
 
 class SentimentAnalysis(Task):
+    shorthand = "sa"
 
-    def __init__(
-        self,
-        model,
-        validation_dataset: str,
-        test_dataset: str,
-        *,
-        use_grammar: bool,
-        validation_split: str | None = None,
-        test_split: str | None = None,
-    ) -> None:
+    def __init__(self, model, options: Namespace):
         super().__init__(
             model,
-            validation_dataset,
-            test_dataset,
-            use_grammar=use_grammar,
-            validation_split=validation_split,
-            test_split=test_split,
+            validation_dataset="SetFit/sst2",
+            test_dataset="SetFit/sst2",
+            use_grammar=options.use_grammar,
+            validation_split=f"validation[:{5 if options.debug else 200}]",
+            test_split="test[:20]" if options.debug else "test",
         )
 
     def predict(self, prompt: str, text: str):
@@ -219,27 +214,19 @@ def qa_grammar_fn(context: str, verbose: bool = False):
 
 
 class QuestionAnswering(Task):
+    shorthand = "qa"
 
-    def __init__(
-        self,
-        model,
-        validation_dataset: str,
-        test_dataset: str,
-        *,
-        use_grammar: bool,
-        validation_split: str | None = None,
-        test_split: str | None = None,
-    ) -> None:
+    def __init__(self, model, options: Namespace):
         self.metric = load_metric("squad")
 
         super().__init__(
             model,
-            validation_dataset,
-            test_dataset,
-            use_grammar=use_grammar,
-            validation_split=validation_split,
-            test_split=test_split,
+            validation_dataset="squad",
+            test_dataset="squad",
+            use_grammar=options.use_grammar,
+            validation_split=f"train[:{5 if options.debug else 200}]",
+            test_split="validation[:20]" if options.debug else "validation",
         )
 
     def predict(self, prompt: str, context: str, question: str):
@@ -330,3 +317,20 @@ class QuestionAnswering(Task):
     def base_prompt(self):
         # TODO find good prompt
         return """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
+
+
+tasks = {task.shorthand: task for task in Task.__subclasses__()}
+
+
+def get_task(name: str, evaluation_model: LLMModel, options: Namespace):
+    if name not in tasks:
+        raise ValueError(f"Task '{name}' does not exist")
+    return tasks[name](evaluation_model, options)
+
+
+argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
+argument_group = argument_parser.add_argument_group("Task arguments")
+argument_group.add_argument("--use-grammar", "-g", action="store_true")
+argument_group.add_argument(
+    "--task", "-t", type=str, required=True, choices=tasks.keys()
+)
--
GitLab
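The core pattern this commit introduces: each module derives its registry from `Base.__subclasses__()`, keyed by a class-level shorthand, so adding a new model, task, or optimizer only requires defining a subclass in the right module. A minimal, self-contained sketch of that pattern (illustrative only; `Base`, `Foo`, and `get` are hypothetical names, not identifiers from this patch):

    class Base:
        shorthand: str

    class Foo(Base):
        shorthand = "foo"

    class Bar(Base):
        shorthand = "bar"

    # Build the registry the same way evolution.py and task.py do.
    registry = {cls.shorthand: cls for cls in Base.__subclasses__()}

    def get(name: str) -> type[Base]:
        if name not in registry:
            raise ValueError(f"'{name}' does not exist")
        return registry[name]

    assert get("foo") is Foo

One caveat: `__subclasses__()` only lists direct subclasses that have already been defined (their module must have been imported), which is why each registry lives in the same module as the subclasses it collects.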