Commit 79a99169 authored by Grießhaber Daniel

Merge branch 'early-stopping' into backend

parents aa579769 4f82a862
from argparse import ArgumentParser
from models import MODELS
# from typing import Literal
# import pydantic
# class Arguments(pydantic.BaseModel):
# # Required Args
# evolution_engine: Literal['llama2', 'openai'] = pydantic.Field("llama2", description="the evolution engine")
# evolution_algorithm: Literal['ga', 'de'] = pydantic.Field(description="the evolution algorithm")
# # model_path: str = pydantic.Field(description="the model path")
# task: Literal['sa', 'qa'] = pydantic.Field(description="the model path")
# use_grammar: bool = pydantic.Field(False, description="use grammar")
# debug: bool = pydantic.Field(None, description="use grammar")
# chat: bool = pydantic.Field(False, description="use chat mode")
argument_parser = ArgumentParser()
argument_parser.add_argument(
"--evolution-engine", "-e", type=str, choices=MODELS.keys(), default="llama2"
)
argument_parser.add_argument(
"--evolution-algorithm", "-a", type=str, choices=["ga", "de"], default="ga"
)
argument_parser.add_argument("--model", "-m", type=str)
argument_parser.add_argument(
"--task", "-t", type=str, required=True, choices=["sa", "qa"]
)
argument_parser.add_argument("--use-grammar", "-g", action="store_true")
argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
argument_parser.add_argument("--chat", "-c", action="store_true")
argument_parser = ArgumentParser(conflict_handler="resolve")
TARGET_PATH='models/llama-2-13b-chat.Q5_K_M.gguf'
TARGET_PATH='models/'
mkdir -p "$(dirname "${TARGET_PATH}")"
curl -L "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve/main/llama-2-13b-chat.Q5_K_M.gguf?download=true" -o ${TARGET_PATH}
curl -L "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve/main/llama-2-13b-chat.Q5_K_M.gguf?download=true" -o ${TARGET_PATH}/llama-2-13b-chat.Q5_K_M.gguf
curl -L "https://huggingface.co/QuantFactory/Meta-Llama-3-8B-GGUF/blob/main/Meta-Llama-3-8B.Q8_0.gguf?download=true" -o ${TARGET_PATH}/Meta-Llama-3-8B.Q8_0.gguf
from abc import abstractmethod
from models import LLMModel
from numpy.random import choice
from tqdm import trange
from cli import argument_parser
from models import LLMModel
from opt_types import ModelUsage, Prompt
from optimization import PromptOptimization, save_snapshot
from task import Task
from tqdm import trange
from utils import initialize_run_directory, log_calls, logger
SYSTEM_MESSAGE = (
@@ -33,7 +35,8 @@ Basic Prompt: {basic_prompt}
class EvolutionAlgorithm(PromptOptimization):
# TODO add docstrings
shorthand: str
"""The super class for all evolution algorithms containing shared parameters."""
def __init__(
@@ -166,7 +169,7 @@ class EvolutionAlgorithm(PromptOptimization):
logger.info(f"Best prompt: {p}")
# We pick the prompt with the highest score on the development set and report its score on the test set.
test_performance, _ = self.task.evaluate_test(p.content)
test_performance, _, _ = self.task.evaluate_test(p.content)
logger.info("Best prompt on test set: %s", test_performance)
logger.info(
"Usage (evolution model / evaluation model / total): %s / %s / %s",
@@ -181,6 +184,8 @@ class EvolutionAlgorithm(PromptOptimization):
class GeneticAlgorithm(EvolutionAlgorithm):
"""The genetic algorithm implemented using LLMs."""
shorthand = "ga"
@log_calls("Performing prompt evolution using GA")
def evolve(
self,
@@ -233,6 +238,8 @@ class GeneticAlgorithm(EvolutionAlgorithm):
class DifferentialEvolution(EvolutionAlgorithm):
"""The genetic algorithm implemented using LLMs."""
shorthand = "de"
@log_calls("Performing prompt evolution using GA")
def evolve(
self,
@@ -275,3 +282,19 @@ class DifferentialEvolution(EvolutionAlgorithm):
)
]
return population
optimizers = {
algorithm.shorthand: algorithm for algorithm in EvolutionAlgorithm.__subclasses__()
}
def get_optimizer_class(name: str):
if name not in optimizers:
raise ValueError("Optimization Algorithm %s does not exist", name)
return optimizers[name]
argument_parser.add_argument(
"--evolution-algorithm", "-a", type=str, choices=optimizers.keys(), default="ga"
)
import os
from typing import Any
from cli import argument_parser
from dotenv import load_dotenv
from evolution import DifferentialEvolution, GeneticAlgorithm
from models import Llama2, get_model_init
from task import QuestionAnswering, SentimentAnalysis
from cli import argument_parser
from evolution import get_optimizer_class
from models import Llama3, LLMModel
from task import get_task
from utils import logger
load_dotenv()
@@ -25,31 +26,24 @@ if __name__ == "__main__":
options = argument_parser.parse_args()
# set up evolution model
model_init_fn = get_model_init(options.evolution_engine)
evolution_model = model_init_fn(
options.model,
chat=options.chat,
)
evolution_model = LLMModel.get_model(options.evolution_engine, options)
match options.evolution_engine:
case "llama2":
logger.info("Using Llama2 client as the evolution engine")
logger.info("Using Llama2 as the evolution engine")
case "openai":
logger.info("Using OpenAI client as the evolution engine")
logger.info(f"Using {options.openai_model} as the evolution engine")
# set up evaluation model
# NOTE currently we always stick to Llama2 as evaluation engine
# TODO allow to set separate engine and model for evaluation?
logger.info("Using Llama2 client as the evaluation engine")
logger.info("Using Llama2 as the evaluation engine")
evaluation_model: LLMModel
match options.evolution_engine:
case "llama2":
case "llama2" | "llama3":
evaluation_model = evolution_model
case "openai":
evaluation_model = Llama2(
model_path=options.model,
chat=options.chat,
)
evaluation_model = Llama3(options)
# log cli arguments
logger.info(
@@ -67,40 +61,14 @@ if __name__ == "__main__":
f"'{os.getenv('EP_DEBUG')}' is not allowed for env variable EP_DEBUG."
)
match options.task:
case "sa":
logger.info("Running with task sentiment analysis on dataset SetFit/sst2")
task = SentimentAnalysis(
evaluation_model,
"SetFit/sst2",
"SetFit/sst2",
use_grammar=options.use_grammar,
validation_split=f"validation[:{5 if debug else 200}]",
test_split="test[:20]" if debug else "test",
)
case "qa":
logger.info("Running with task question answering on dataset squad")
task = QuestionAnswering(
evaluation_model,
"squad",
"squad",
use_grammar=options.use_grammar,
validation_split=f"train[:{5 if debug else 200}]",
test_split="validation[:20]" if debug else "validation",
)
case _:
raise ValueError(
f"Task {options.task} does not exist. Choose from 'sa', 'qa'."
)
task = get_task(options.task, evaluation_model, options)
logger.info(
f"Running with task {task.__class__.__name__} on dataset {task.validation_dataset.info.dataset_name}"
)
logger.info("Using evolutionary algorithm '%s'", options.evolution_algorithm)
# TODO allow to register algorithms and map to classes
if options.evolution_algorithm == "ga":
optimizer_class = GeneticAlgorithm
else:
optimizer_class = DifferentialEvolution
optimizer_class = get_optimizer_class(options.evolution_algorithm)
optimizer = optimizer_class(
population_size=10,
task=task,
......
from abc import abstractmethod
import abc
import inspect
from abc import ABC, abstractmethod, abstractproperty
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Any, Callable
from typing import Any, ClassVar
import openai
from llama_cpp import Llama
from cli import argument_parser
from opt_types import ModelUsage
current_directory = Path(__file__).resolve().parent
MODELS = {}
def register_model(name: str):
def wrapper(f: Callable):
global MODELS
if name in MODELS:
raise ValueError("Cannot register model class %s: already exists", name)
MODELS[name] = f
return f
return wrapper
def get_models():
return MODELS
def get_model_init(name: str):
if name not in MODELS:
raise ValueError("Model %s does not exist", name)
return MODELS[name]
class LLMModel(ABC):
models: ClassVar[dict[str, type["LLMModel"]]] = {}
chat: bool
def __init_subclass__(cls) -> None:
if inspect.isabstract(cls):
return
cls.models[cls.__name__.lower()] = cls
cls.register_arguments(argument_parser)
class LLMModel:
chat: bool
model: Any
@classmethod
def get_model(cls, name: str, options: Namespace):
if name not in cls.models:
raise ValueError("Model %s does not exist", name)
return cls.models[name](options)
def __init__(self, chat: bool, model: Any):
def __init__(self, options: Namespace):
self.usage = ModelUsage()
self.chat = chat
self.model = model
self.chat = options.chat
@abstractmethod
def __call__(
@@ -59,15 +50,17 @@ class LLMModel:
) -> Any:
pass
@classmethod
@abstractmethod
def register_arguments(cls, parser: ArgumentParser):
pass
@register_model("llama2")
class Llama2(LLMModel):
"""Loads and queries a Llama2 model."""
class LlamaModel(LLMModel):
def __init__(
self,
model_path: str,
chat: bool = False,
options: Namespace,
n_gpu_layers: int = 60,
n_threads: int = 8,
n_ctx: int = 4096,
@@ -76,16 +69,17 @@ class Llama2(LLMModel):
) -> None:
# initialize model
model = Llama(
model_path,
chat_format="llama-2",
self.model = Llama(
options.llama_path,
chat_format=self.chat_format,
verbose=verbose,
n_gpu_layers=n_gpu_layers,
n_threads=n_threads,
n_ctx=n_ctx,
**kwargs,
)
super().__init__(chat, model)
super().__init__(options)
def __call__(
self,
@@ -142,19 +136,42 @@ class Llama2(LLMModel):
self.usage += usage
return response_text, usage
@property
@abstractmethod
def chat_format(self) -> str:
pass
@classmethod
def register_arguments(cls, parser: ArgumentParser):
group = parser.add_argument_group(f"{cls.__name__} model arguments")
group.add_argument(
"--llama-path", default="models/llama-2-13b-chat.Q5_K_M.gguf"
)
class Llama2(LlamaModel):
@property
def chat_format(self) -> str:
return "llama-2"
class Llama3(LlamaModel):
@property
def chat_format(self) -> str:
return "llama-3"
@register_model("openai")
class OpenAI(LLMModel):
"""Queries an OpenAI model using its API."""
def __init__(
self,
model: str = "gpt-3.5-turbo",
chat: bool = False,
options: Namespace,
verbose: bool = False,
**kwargs,
) -> None:
super().__init__(chat, model)
self.model_name = options.openai_model
super().__init__(options)
# initialize client for API calls
self.openai_client = openai.OpenAI(**kwargs)
@@ -191,7 +208,7 @@ class OpenAI(LLMModel):
},
)
response = self.openai_client.chat.completions.create(
model=self.model,
model=self.model_name,
messages=messages,
stop=stop,
max_tokens=max_tokens,
@@ -215,3 +232,19 @@ class OpenAI(LLMModel):
usage = ModelUsage(**response.usage.__dict__)
self.usage += usage
return response.choices[0].text, usage
@classmethod
def register_arguments(cls, parser: ArgumentParser):
group = parser.add_argument_group("OpenAI model arguments")
group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
argument_group = argument_parser.add_argument_group("Model arguments")
argument_group.add_argument(
"--evolution-engine",
"-e",
type=str,
choices=LLMModel.models.keys(),
default="llama2",
)
argument_group.add_argument("--chat", "-c", action="store_true")
@@ -30,6 +30,7 @@ class Prompt:
content: str
score: float
usage: ModelUsage
evaluation_history: list[float]
meta: dict = field(default_factory=dict)
id: str = field(default_factory=lambda: uuid4().hex)
......
@@ -57,22 +57,36 @@ class PromptOptimization:
def reset(self):
self._init()
def evaluate_prompt(self, prompt: str):
return self.task.evaluate_validation(prompt)
def evaluate_prompt(self, prompt: str, parents: tuple[Prompt] | None = None):
parent_histories = (
[parent.evaluation_history for parent in parents]
if parents is not None
else None
)
return self.task.evaluate_validation(prompt, parent_histories)
def add_prompt(
self, prompt: str, parents: tuple[Prompt] = None, meta: dict = None
self,
prompt: str,
parents: tuple[Prompt] | None = None,
meta: dict | None = None,
) -> Prompt:
score, usage = self.evaluate_prompt(prompt)
prompt = Prompt(content=prompt, score=score, meta=meta, usage=usage)
score, usage, history = self.evaluate_prompt(prompt, parents)
prompt_object = Prompt(
content=prompt,
score=score,
meta=meta if meta is not None else {},
usage=usage,
evaluation_history=history,
)
# keep track of prompt
self.all_prompts[prompt.id] = prompt
self.family_tree[prompt.id] = (
self.all_prompts[prompt_object.id] = prompt_object
self.family_tree[prompt_object.id] = (
tuple(p.id for p in parents) if parents is not None else None
)
return prompt
return prompt_object
def add_prompts(
self,
@@ -86,7 +100,7 @@ class PromptOptimization:
]
def get_prompt(self, prompt_id: str):
return self.all_prompt[prompt_id]
return self.all_prompts[prompt_id]
def get_prompts(self, prompt_ids: list[str]):
return [self.get_prompt(p_id) for p_id in prompt_ids]
......
import re
from abc import abstractmethod
from collections import defaultdict
from argparse import Namespace
from functools import lru_cache
from typing import DefaultDict, Mapping, Union
from statistics import mean
from typing import Union
from datasets import Dataset, load_dataset
from evaluate import load as load_metric
from llama_cpp import LlamaGrammar
from models import Llama2, OpenAI
from opt_types import ModelUsage
from collections import deque
from llama_cpp import LlamaGrammar
from torch.utils import data
from tqdm import tqdm
from cli import argument_parser
from models import Llama2, LLMModel, OpenAI
from opt_types import ModelUsage
from utils import log_calls, logger
SYSTEM_MESSAGE = """
You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
"""
DatasetDatum = dict
class EarlyStoppingMonitor:
@abstractmethod
def update(self, score: float) -> bool:
raise NotImplementedError
class MomentBasedStopping(EarlyStoppingMonitor):
"""
Watch the first derivative (moment) of the metric to determine when to stop.
"""
def __init__(
self,
*,
patience: int = 10,
start_after: int = 20,
min_moment_magnitude: float = 0.001,
):
self.patience = patience
self.start_after = start_after
self.min_moment_magnitude = min_moment_magnitude
self.moment_magnitudes = deque(maxlen=patience)
self.last_score = 0.0
self.num_calls = 0
def update(self, score: float) -> bool:
# calculate the current moment (dx/dt)
self.num_calls += 1
if self.num_calls < self.start_after:
return False
self.moment_magnitudes.append(abs(score - self.last_score))
self.last_score = score
if len(self.moment_magnitudes) < self.patience:
return False
if mean(self.moment_magnitudes) < self.min_moment_magnitude:
return True
return False
class ParentBaselineBasedStopping(EarlyStoppingMonitor):
def __init__(
self,
parent_histories: list[list[float]],
*,
patience: int = 10,
start_after: int = 20,
min_improvement: float = 0.001,
):
self.parent_histories = parent_histories
self.patience = patience
self.start_after = start_after
self.min_improvement = min_improvement
self.num_calls = 0
self.improvement_memory = deque(maxlen=patience)
def update(self, score: float) -> bool:
self.num_calls += 1
if self.num_calls < self.start_after:
return False
parent_values = [ # get the metric value of the parents at the current step
(
parent_history[self.num_calls - 1]
if len(parent_history) >= self.num_calls
else parent_history[-1] # extend with last value
)
for parent_history in self.parent_histories
]
self.improvement_memory.append(
score - max(parent_values) # compare with the best parent
)
if len(self.improvement_memory) < self.patience:
return False
if max(self.improvement_memory) < self.min_improvement:
# if the highest improvement is less than the minimum improvement, we stop
return True
return False
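A minimal usage sketch of the early-stopping monitors defined above (not part of the commit; the score trace is synthetic and purely illustrative):

# Feed a converging accuracy curve into the monitor; once the mean absolute
# change over the last `patience` scores drops below min_moment_magnitude,
# update() returns True and evaluation would stop early.
monitor = MomentBasedStopping(patience=5, start_after=10, min_moment_magnitude=0.01)
scores = [0.5 + 0.4 * (1 - 0.9**i) for i in range(60)]
for step, score in enumerate(scores, start=1):
    if monitor.update(score):
        print(f"early stop after {step} of {len(scores)} evaluations")
        break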
class Task:
shorthand: str
validation_dataset: Dataset
test_dataset: Dataset
def __init__(
self,
model: Union[Llama2, OpenAI],
@@ -43,26 +141,76 @@ class Task:
pass
@abstractmethod
def _evaluate(self, prompt: str, dataset) -> tuple[float, ModelUsage]:
def _evaluate_sample(
self, prompt: str, datum: DatasetDatum
) -> tuple[str, ModelUsage]:
pass
@abstractmethod
def _aggregate_result(self, results: list) -> float:
pass
def evaluate(
self,
prompt: str,
dataset: Dataset,
parent_histories: list[list[float]] | None = None,
) -> tuple[float, ModelUsage, list[float]]:
early_stopping: EarlyStoppingMonitor
early_stopping_params = {
"patience": max(len(dataset) // 20, 5),
"start_after": max(len(dataset) // 5, 5),
}
if parent_histories is not None:
early_stopping = ParentBaselineBasedStopping(
parent_histories, **early_stopping_params
)
else:
early_stopping = MomentBasedStopping(**early_stopping_params)
results: list = []
dataset_iterator: tqdm[DatasetDatum] = tqdm(
dataset, desc="evaluating prompt", leave=False
)
evaluation_usage = ModelUsage()
evaluation_history = []
for datum in dataset_iterator:
result, usage = self._evaluate_sample(prompt, datum)
results.append(result)
current_metric = self._aggregate_result(results)
dataset_iterator.set_postfix(
{self.metric_name: f"{current_metric*100:.1f}%"}
)
evaluation_usage += usage
evaluation_history.append(current_metric)
if early_stopping.update(current_metric):
logger.info(
f"Early stopping after {len(results)} samples with {self.metric_name} of {current_metric*100:.1f}%"
)
break
return self._aggregate_result(results), evaluation_usage, evaluation_history
@log_calls("Evaluating validation dataset")
@lru_cache(maxsize=None)
def evaluate_validation(self, prompt: str):
return self._evaluate(prompt, self.validation_dataset)
def evaluate_validation(
self, prompt: str, parent_histories: list[list[float]] | None = None
):
return self.evaluate(prompt, self.validation_dataset, parent_histories)
@log_calls("Evaluating test dataset")
def evaluate_test(self, prompt: str):
return self._evaluate(prompt, self.test_dataset)
return self.evaluate(prompt, self.test_dataset)
@property
@abstractmethod
def metric_name(self):
def metric_name(self) -> str:
pass
@property
@abstractmethod
def base_prompt(self):
def base_prompt(self) -> str:
pass
@@ -80,24 +228,16 @@ def sa_grammar_fn(verbose: bool = False):
class SentimentAnalysis(Task):
shorthand = "sa"
def __init__(
self,
model,
validation_dataset: str,
test_dataset: str,
*,
use_grammar: bool,
validation_split: str | None = None,
test_split: str | None = None,
) -> None:
def __init__(self, model, options: Namespace):
super().__init__(
model,
validation_dataset,
test_dataset,
use_grammar=use_grammar,
validation_split=validation_split,
test_split=test_split,
validation_dataset="SetFit/sst2",
test_dataset="SetFit/sst2",
use_grammar=options.use_grammar,
validation_split=f"validation[:{5 if options.debug else 200}]",
test_split="test[:20]" if options.debug else "test",
)
def predict(self, prompt: str, text: str):
@@ -117,39 +257,33 @@ class SentimentAnalysis(Task):
return response, usage
def _evaluate(self, prompt: str, dataset: Dataset):
def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
sst2_labels = {"negative": 0, "positive": 1}
results: DefaultDict[str, int] = defaultdict(int)
dataset_iterator = tqdm(dataset, desc="evaluating prompt", leave=False)
evaluation_usage = ModelUsage()
for datum in dataset_iterator:
response, usage = self.predict(prompt=prompt, text=datum["text"])
response = response.lower()
evaluation_usage += usage
if self.use_grammar:
# model output is from label space
answer_label = sst2_labels[response]
response, usage = self.predict(prompt=prompt, text=datum["text"])
response = response.lower()
if self.use_grammar:
# model output is from label space
answer_label = sst2_labels[response]
else:
answer_label = None
for label in sst2_labels.keys():
if label in response:
answer_label = sst2_labels[label]
break
else:
answer_label = None
for label in sst2_labels.keys():
if label in response:
answer_label = sst2_labels[label]
break
else:
logger.warning(f"Invalid answer: {response}")
results["failed"] += 1
continue
classification_result = (
"incorrect" if answer_label != datum["label"] else "correct"
)
results[classification_result] += 1
dataset_iterator.set_postfix(results)
logger.warning(f"Invalid answer: {response}")
return "failed", usage
classification_result = (
"incorrect" if answer_label != datum["label"] else "correct"
)
return classification_result, usage
accuracy = results["correct"] / sum(results.values())
return accuracy, evaluation_usage
def _aggregate_result(self, results: list[str]) -> float:
num_correct_results = sum(1 for result in results if result == "correct")
accuracy = num_correct_results / len(results)
return accuracy
@property
def metric_name(self):
@@ -196,27 +330,19 @@ def qa_grammar_fn(context: str, verbose: bool = False):
class QuestionAnswering(Task):
shorthand = "qa"
def __init__(
self,
model,
validation_dataset: str,
test_dataset: str,
*,
use_grammar: bool,
validation_split: str | None = None,
test_split: str | None = None,
) -> None:
def __init__(self, model, options: Namespace):
self.metric = load_metric("squad")
super().__init__(
model,
validation_dataset,
test_dataset,
use_grammar=use_grammar,
validation_split=validation_split,
test_split=test_split,
"squad",
"squad",
use_grammar=options.use_grammar,
validation_split=f"train[:{5 if options.debug else 200}]",
test_split="validation[:20]" if options.debug else "validation",
)
def predict(self, prompt: str, context: str, question: str):
@@ -255,10 +381,26 @@ class QuestionAnswering(Task):
return response, usage
def _evaluate(self, prompt: str, dataset: Dataset):
evaluation_usage = ModelUsage()
def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
answer, usage = self.predict(
prompt,
context=datum["context"],
question=datum["question"],
)
# TODO check if answer is lower-cased in metric computation
result = self.metric.compute(
predictions=[{"prediction_text": answer, "id": datum["id"]}],
references=[{"answers": datum["answers"], "id": datum["id"]}],
)
return result["f1"] / 100, usage
def _aggregate_result(self, results: list[float]) -> float:
return sum(results) / len(results)
def replace_symbol_for_grammar(sample: Mapping):
def evaluate(self, prompt: str, dataset: Dataset, parent_histories: list[list[float]] | None = None):
def replace_symbol_for_grammar(sample: DatasetDatum):
symbol_replacement_mapping = {
"\u2013": "-",
"\u2014": "-",
@@ -281,35 +423,7 @@ class QuestionAnswering(Task):
if self.use_grammar:
# NOTE: the LlamaGrammar has issues with the symbol '–', therefore we replace all occurrences with '-' (hyphen)
dataset = dataset.map(replace_symbol_for_grammar, desc="Replacing symbols")
dataset_iterator = tqdm(dataset, desc="evaluating prompt", leave=False)
num_samples = 0
f1 = 0.0
em = 0
for datum in dataset_iterator:
answer, usage = self.predict(
prompt,
context=datum["context"],
question=datum["question"],
)
# TODO check if answer is lower-cased in metric computation
evaluation_usage += usage
num_samples += 1
result = self.metric.compute(
predictions=[{"prediction_text": answer, "id": datum["id"]}],
references=[{"answers": datum["answers"], "id": datum["id"]}],
)
f1 += result["f1"]
em += result["exact_match"]
dataset_iterator.set_postfix(
{"f1": f1 / num_samples, "em": em / num_samples}
)
return f1 / num_samples, evaluation_usage
return super().evaluate(prompt, dataset, parent_histories)
@property
def metric_name(self):
@@ -319,3 +433,20 @@ class QuestionAnswering(Task):
def base_prompt(self):
# TODO find good prompt
return """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
tasks = {task.shorthand: task for task in Task.__subclasses__()}
def get_task(name: str, evaluation_model: LLMModel, options: Namespace):
if name not in tasks:
raise ValueError("Model %s does not exist", name)
return tasks[name](evaluation_model, options)
argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
argument_group = argument_parser.add_argument_group("Task arguments")
argument_group.add_argument("--use-grammar", "-g", action="store_true")
argument_group.add_argument(
"--task", "-t", type=str, required=True, choices=["sa", "qa"]
)