Commit d47234e3 authored by Grießhaber Daniel

move cli parameters to local modules

parent 82f55d87
cli.py

 from argparse import ArgumentParser
-from models import MODELS
-
-# from typing import Literal
-# import pydantic
-# class Arguments(pydantic.BaseModel):
-#     # Required Args
-#     evolution_engine: Literal['llama2', 'openai'] = pydantic.Field("llama2", description="the evolution engine")
-#     evolution_algorithm: Literal['ga', 'de'] = pydantic.Field(description="the evolution algorithm")
-#     # model_path: str = pydantic.Field(description="the model path")
-#     task: Literal['sa', 'qa'] = pydantic.Field(description="the model path")
-#     use_grammar: bool = pydantic.Field(False, description="use grammar")
-#     debug: bool = pydantic.Field(None, description="use grammar")
-#     chat: bool = pydantic.Field(False, description="use chat mode")
 
 argument_parser = ArgumentParser()
-argument_parser.add_argument(
-    "--evolution-engine", "-e", type=str, choices=MODELS.keys(), default="llama2"
-)
-argument_parser.add_argument(
-    "--evolution-algorithm", "-a", type=str, choices=["ga", "de"], default="ga"
-)
-argument_parser.add_argument(
-    "--task", "-t", type=str, required=True, choices=["sa", "qa"]
-)
-argument_parser.add_argument("--use-grammar", "-g", action="store_true")
-argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
-argument_parser.add_argument("--chat", "-c", action="store_true")
-argument_parser.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
-argument_parser.add_argument(
-    "--llama-path", default="models/llama-2-13b-chat.Q5_K_M.gguf"
-)
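After this change, cli.py only creates the shared parser; every module that owns a flag imports argument_parser and registers its options itself. A minimal sketch of the shared-parser pattern (all names below are illustrative, not from the repository):

    # shared parser module (stand-in for cli.py)
    from argparse import ArgumentParser

    argument_parser = ArgumentParser()

    # a feature module imports the parser and attaches its own flags at import time
    feature_group = argument_parser.add_argument_group("feature arguments")
    feature_group.add_argument("--feature-flag", action="store_true")

    # the entry point parses only after all contributing modules are imported
    options = argument_parser.parse_args(["--feature-flag"])
    print(options.feature_flag)  # True

One consequence of this design, visible in the modules below, is that registration happens as an import side effect: the entry point must import every module that contributes flags before calling parse_args().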
evolution.py

 from abc import abstractmethod
-from models import LLMModel
 from numpy.random import choice
-from tqdm import trange
+from cli import argument_parser
+from models import LLMModel
 from opt_types import ModelUsage, Prompt
 from optimization import PromptOptimization, save_snapshot
 from task import Task
+from tqdm import trange
 from utils import initialize_run_directory, log_calls, logger
 
 SYSTEM_MESSAGE = (
@@ -33,7 +35,8 @@ Basic Prompt: {basic_prompt}
 
 class EvolutionAlgorithm(PromptOptimization):
-    # TODO add docstrings
     """The super class for all evolution algorithms containing shared parameters."""
+
+    shorthand: str
 
     def __init__(
@@ -181,6 +184,8 @@ class EvolutionAlgorithm(PromptOptimization):
 class GeneticAlgorithm(EvolutionAlgorithm):
     """The genetic algorithm implemented using LLMs."""
 
+    shorthand = "ga"
+
     @log_calls("Performing prompt evolution using GA")
     def evolve(
         self,
@@ -233,6 +238,8 @@ class GeneticAlgorithm(EvolutionAlgorithm):
 class DifferentialEvolution(EvolutionAlgorithm):
     """The differential evolution algorithm implemented using LLMs."""
 
+    shorthand = "de"
+
     @log_calls("Performing prompt evolution using DE")
     def evolve(
         self,
@@ -275,3 +282,19 @@ class DifferentialEvolution(EvolutionAlgorithm):
         )
     ]
     return population
+
+
+optimizers = {
+    algorithm.shorthand: algorithm for algorithm in EvolutionAlgorithm.__subclasses__()
+}
+
+
+def get_optimizer_class(name: str):
+    if name not in optimizers:
+        raise ValueError(f"Optimization algorithm {name} does not exist")
+    return optimizers[name]
+
+
+argument_parser.add_argument(
+    "--evolution-algorithm", "-a", type=str, choices=optimizers.keys(), default="ga"
+)
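The optimizers dict replaces the hard-coded if/else in the entry script: each direct subclass of EvolutionAlgorithm announces its CLI name via shorthand, and the registry is derived from __subclasses__(). A self-contained sketch of the pattern (class names illustrative):

    class Base:
        shorthand: str

    class Alpha(Base):
        shorthand = "a"

    class Beta(Base):
        shorthand = "b"

    # build the name -> class mapping from the subclasses defined so far
    registry = {cls.shorthand: cls for cls in Base.__subclasses__()}

    def get_class(name: str):
        if name not in registry:
            raise ValueError(f"{name!r} is not registered")
        return registry[name]

    assert get_class("a") is Alpha
    assert get_class("b") is Beta

Note that __subclasses__() only returns direct subclasses that have already been defined when the comprehension runs; indirect subclasses, or classes in modules that were never imported, do not appear in the registry.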
(entry-point script)

@@ -4,9 +4,9 @@ from typing import Any
 from dotenv import load_dotenv
 
 from cli import argument_parser
-from evolution import DifferentialEvolution, GeneticAlgorithm
-from models import Llama2, get_model_init
-from task import QuestionAnswering, SentimentAnalysis
+from evolution import DifferentialEvolution, GeneticAlgorithm, get_optimizer_class
+from models import Llama2, LLMModel, OpenAI, get_model
+from task import QuestionAnswering, SentimentAnalysis, get_task
 from utils import logger
 
 load_dotenv()
@@ -26,8 +26,7 @@ if __name__ == "__main__":
     options = argument_parser.parse_args()
 
     # set up evolution model
-    model_init_fn = get_model_init(options.evolution_engine)
-    evolution_model = model_init_fn(options)
+    evolution_model = get_model(options.evolution_engine, options)
 
     match options.evolution_engine:
         case "llama2":
@@ -39,6 +38,7 @@ if __name__ == "__main__":
     # NOTE currently we always stick to Llama2 as evaluation engine
     # TODO allow to set separate engine and model for evaluation?
     logger.info("Using Llama2 as the evaluation engine")
+    evaluation_model: LLMModel
     match options.evolution_engine:
         case "llama2":
             evaluation_model = evolution_model
@@ -61,40 +61,14 @@
             f"'{os.getenv('EP_DEBUG')}' is not allowed for env variable EP_DEBUG."
         )
 
-    match options.task:
-        case "sa":
-            logger.info("Running with task sentiment analysis on dataset SetFit/sst2")
-            task = SentimentAnalysis(
-                evaluation_model,
-                "SetFit/sst2",
-                "SetFit/sst2",
-                use_grammar=options.use_grammar,
-                validation_split=f"validation[:{5 if debug else 200}]",
-                test_split="test[:20]" if debug else "test",
-            )
-        case "qa":
-            logger.info("Running with task question answering on dataset squad")
-            task = QuestionAnswering(
-                evaluation_model,
-                "squad",
-                "squad",
-                use_grammar=options.use_grammar,
-                validation_split=f"train[:{5 if debug else 200}]",
-                test_split="validation[:20]" if debug else "validation",
-            )
-        case _:
-            raise ValueError(
-                f"Task {options.task} does not exist. Choose from 'sa', 'qa'."
-            )
+    task = get_task(options.task, evaluation_model, options)
+    logger.info(
+        f"Running with task {task.__class__.__name__} on dataset {task.validation_dataset.info.dataset_name}"
+    )
 
     logger.info("Using evolutionary algorithm '%s'", options.evolution_algorithm)
-    # TODO allow to register algorithms and map to classes
-    if options.evolution_algorithm == "ga":
-        optimizer_class = GeneticAlgorithm
-    else:
-        optimizer_class = DifferentialEvolution
+    optimizer_class = get_optimizer_class(options.evolution_algorithm)
 
     optimizer = optimizer_class(
         population_size=10,
         task=task,
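With every flag now registered by its owning module, the assembled parser can be exercised end to end. A hypothetical smoke test (it assumes the modules shown in this commit are importable, i.e. llama_cpp and openai are installed):

    # importing the modules is what registers their flags on the shared parser
    import evolution  # noqa: F401
    import models  # noqa: F401
    import task  # noqa: F401

    from cli import argument_parser

    options = argument_parser.parse_args(
        ["--task", "sa", "--evolution-engine", "llama2", "--evolution-algorithm", "de"]
    )
    print(options.task, options.evolution_engine, options.evolution_algorithm)
    # sa llama2 de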
models.py

 from abc import abstractmethod
-from argparse import Namespace
+from argparse import ArgumentParser, Namespace
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any
 
 import openai
 from llama_cpp import Llama
 
+from cli import argument_parser
 from opt_types import ModelUsage
 
 current_directory = Path(__file__).resolve().parent
@@ -34,33 +35,12 @@ class LLMModel:
     ) -> Any:
         pass
 
-
-MODELS: dict[str, type[LLMModel]] = {}
-
-
-def register_model(name: str):
-    def wrapper(f: type[LLMModel]):
-        global MODELS
-        if name in MODELS:
-            raise ValueError("Cannot register model class %s: already exists", name)
-        MODELS[name] = f
-        return f
-
-    return wrapper
-
-
-def get_models():
-    return MODELS
-
-
-def get_model_init(name: str):
-    if name not in MODELS:
-        raise ValueError("Model %s does not exist", name)
-    return MODELS[name]
+    @classmethod
+    @abstractmethod
+    def register_arguments(cls, parser: ArgumentParser):
+        pass
 
 
-@register_model("llama2")
 class Llama2(LLMModel):
     """Loads and queries a Llama2 model."""
@@ -142,8 +122,14 @@ class Llama2(LLMModel):
         self.usage += usage
         return response_text, usage
 
+    @classmethod
+    def register_arguments(cls, parser: ArgumentParser):
+        group = parser.add_argument_group("Llama2 model arguments")
+        group.add_argument(
+            "--llama-path", default="models/llama-2-13b-chat.Q5_K_M.gguf"
+        )
 
-@register_model("openai")
 class OpenAI(LLMModel):
     """Queries an OpenAI model using its API."""
@@ -215,3 +201,26 @@ class OpenAI(LLMModel):
         usage = ModelUsage(**response.usage.__dict__)
         self.usage += usage
         return response.choices[0].text, usage
+
+    @classmethod
+    def register_arguments(cls, parser: ArgumentParser):
+        group = parser.add_argument_group("OpenAI model arguments")
+        group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
+
+
+models = {model.__name__.lower(): model for model in LLMModel.__subclasses__()}
+
+for name, model in models.items():
+    model.register_arguments(argument_parser)
+
+
+def get_model(name: str, options: Namespace):
+    if name not in models:
+        raise ValueError(f"Model {name} does not exist")
+    return models[name](options)
+
+
+argument_group = argument_parser.add_argument_group("Model arguments")
+argument_group.add_argument(
+    "--evolution-engine", "-e", type=str, choices=models.keys(), default="llama2"
+)
+argument_group.add_argument("--chat", "-c", action="store_true")
task.py

 import re
 from abc import abstractmethod
+from argparse import Namespace
 from functools import lru_cache
 from typing import Union
@@ -8,7 +9,8 @@ from evaluate import load as load_metric
 from llama_cpp import LlamaGrammar
 from tqdm import tqdm
 
-from models import Llama2, OpenAI
+from cli import argument_parser
+from models import Llama2, LLMModel, OpenAI
 from opt_types import ModelUsage
 from utils import log_calls, logger
@@ -20,6 +22,7 @@ DatasetDatum = dict
 
 class Task:
+    shorthand: str
     validation_dataset: Dataset
     test_dataset: Dataset
@@ -67,9 +70,9 @@ class Task:
         for datum in dataset_iterator:
             result, usage = self._evaluate_sample(prompt, datum)
             results.append(result)
-            current_metrics = self._aggregate_result(results)
+            current_metric = self._aggregate_result(results)
             dataset_iterator.set_postfix(
-                {self.metric_name: f"{current_metrics*100:.1f}%"}
+                {self.metric_name: f"{current_metric*100:.1f}%"}
             )
             evaluation_usage += usage
@@ -109,24 +112,16 @@ def sa_grammar_fn(verbose: bool = False):
 
 class SentimentAnalysis(Task):
+    shorthand = "sa"
 
-    def __init__(
-        self,
-        model,
-        validation_dataset: str,
-        test_dataset: str,
-        *,
-        use_grammar: bool,
-        validation_split: str | None = None,
-        test_split: str | None = None,
-    ) -> None:
+    def __init__(self, model, options: Namespace):
         super().__init__(
             model,
-            validation_dataset,
-            test_dataset,
-            use_grammar=use_grammar,
-            validation_split=validation_split,
-            test_split=test_split,
+            validation_dataset="SetFit/sst2",
+            test_dataset="SetFit/sst2",
+            use_grammar=options.use_grammar,
+            validation_split=f"validation[:{5 if options.debug else 200}]",
+            test_split="test[:20]" if options.debug else "test",
         )
 
     def predict(self, prompt: str, text: str):
@@ -219,27 +214,19 @@ def qa_grammar_fn(context: str, verbose: bool = False):
 
 class QuestionAnswering(Task):
+    shorthand = "qa"
 
-    def __init__(
-        self,
-        model,
-        validation_dataset: str,
-        test_dataset: str,
-        *,
-        use_grammar: bool,
-        validation_split: str | None = None,
-        test_split: str | None = None,
-    ) -> None:
+    def __init__(self, model, options: Namespace):
         self.metric = load_metric("squad")
         super().__init__(
             model,
-            validation_dataset,
-            test_dataset,
-            use_grammar=use_grammar,
-            validation_split=validation_split,
-            test_split=test_split,
+            "squad",
+            "squad",
+            use_grammar=options.use_grammar,
+            validation_split=f"train[:{5 if options.debug else 200}]",
+            test_split="validation[:20]" if options.debug else "validation",
        )
 
     def predict(self, prompt: str, context: str, question: str):
@@ -330,3 +317,20 @@ class QuestionAnswering(Task):
     def base_prompt(self):
         # TODO find good prompt
         return """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
+
+
+tasks = {task.shorthand: task for task in Task.__subclasses__()}
+
+
+def get_task(name: str, evaluation_model: LLMModel, options: Namespace):
+    if name not in tasks:
+        raise ValueError(f"Task {name} does not exist")
+    return tasks[name](evaluation_model, options)
+
+
+argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
+argument_group = argument_parser.add_argument_group("Task arguments")
+argument_group.add_argument("--use-grammar", "-g", action="store_true")
+argument_group.add_argument(
+    "--task", "-t", type=str, required=True, choices=["sa", "qa"]
+)
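get_task mirrors get_model: each Task subclass now encapsulates its own dataset names and splits and reads everything else off the parsed Namespace. A small sketch of that constructor shape (DemoTask is invented for illustration):

    from argparse import Namespace

    class DemoTask:
        shorthand = "demo"

        def __init__(self, model, options: Namespace):
            self.model = model
            # mirror the debug-dependent split selection used above
            self.validation_split = f"validation[:{5 if options.debug else 200}]"

    tasks = {cls.shorthand: cls for cls in (DemoTask,)}

    def get_task(name: str, model, options: Namespace):
        if name not in tasks:
            raise ValueError(f"Task {name} does not exist")
        return tasks[name](model, options)

    task = get_task("demo", model=None, options=Namespace(debug=True))
    print(task.validation_split)  # validation[:5]

The split strings ("validation[:5]", "test[:20]", ...) are Hugging Face datasets slicing syntax, so the --debug flag simply shrinks the evaluated subsets.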