diff --git a/cli.py b/cli.py
index 071c220d14580eeed8cbf0263a894ddd80790ee8..0e421f74ea03f3b6c52b16ef9ade2d489192a98f 100644
--- a/cli.py
+++ b/cli.py
@@ -1,34 +1,3 @@
 from argparse import ArgumentParser
-from models import MODELS
-
-# from typing import Literal
-
-# import pydantic
-
-# class Arguments(pydantic.BaseModel):
-#     # Required Args
-#     evolution_engine: Literal['llama2', 'openai'] = pydantic.Field("llama2", description="the evolution engine")
-#     evolution_algorithm: Literal['ga', 'de'] = pydantic.Field(description="the evolution algorithm")
-#     # model_path: str = pydantic.Field(description="the model path")
-#     task: Literal['sa', 'qa'] = pydantic.Field(description="the model path")
-#     use_grammar: bool = pydantic.Field(False, description="use grammar")
-#     debug: bool = pydantic.Field(None, description="use grammar")
-#     chat: bool = pydantic.Field(False, description="use chat mode")
-
-
-argument_parser = ArgumentParser()
-
-argument_parser.add_argument(
-    "--evolution-engine", "-e", type=str, choices=MODELS.keys(), default="llama2"
-)
-argument_parser.add_argument(
-    "--evolution-algorithm", "-a", type=str, choices=["ga", "de"], default="ga"
-)
-argument_parser.add_argument("--model", "-m", type=str)
-argument_parser.add_argument(
-    "--task", "-t", type=str, required=True, choices=["sa", "qa"]
-)
-argument_parser.add_argument("--use-grammar", "-g", action="store_true")
-argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
-argument_parser.add_argument("--chat", "-c", action="store_true")
+argument_parser = ArgumentParser(conflict_handler="resolve")
diff --git a/download_model.sh b/download_model.sh
index 3ef764bed81f136bc3fbd1258cb49c270cfc4a4f..5f6ab36572aba6bdb803933db2185110d925c94a 100644
--- a/download_model.sh
+++ b/download_model.sh
@@ -1,3 +1,4 @@
-TARGET_PATH='models/llama-2-13b-chat.Q5_K_M.gguf'
-mkdir -p "$(dirname "${TARGET_PATH}")"
-curl -L "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve/main/llama-2-13b-chat.Q5_K_M.gguf?download=true" -o ${TARGET_PATH}
+TARGET_PATH='models'
+mkdir -p "${TARGET_PATH}"
+curl -L "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve/main/llama-2-13b-chat.Q5_K_M.gguf?download=true" -o "${TARGET_PATH}/llama-2-13b-chat.Q5_K_M.gguf"
+curl -L "https://huggingface.co/QuantFactory/Meta-Llama-3-8B-GGUF/resolve/main/Meta-Llama-3-8B.Q8_0.gguf?download=true" -o "${TARGET_PATH}/Meta-Llama-3-8B.Q8_0.gguf"
diff --git a/evolution.py b/evolution.py
index 200a1b5bdd261c0190898bcb026a540b7dc37a16..d2c3e1fcef407ebc3c70e23c6420d9aeb7230062 100644
--- a/evolution.py
+++ b/evolution.py
@@ -1,11 +1,13 @@
 from abc import abstractmethod
 
-from models import LLMModel
 from numpy.random import choice
+from tqdm import trange
+
+from cli import argument_parser
+from models import LLMModel
 from opt_types import ModelUsage, Prompt
 from optimization import PromptOptimization, save_snapshot
 from task import Task
-from tqdm import trange
 from utils import initialize_run_directory, log_calls, logger
 
 SYSTEM_MESSAGE = (
@@ -33,7 +35,8 @@ Basic Prompt: {basic_prompt}
 
 
 class EvolutionAlgorithm(PromptOptimization):
-    # TODO add docstrings
     """The super class for all evolution algorithms containing shared parameters."""
 
+    shorthand: str
+
     def __init__(
@@ -166,7 +169,7 @@ class EvolutionAlgorithm(PromptOptimization):
             logger.info(f"Best prompt: {p}")
 
         # We pick the prompt with the highest score on the development set and report its score on the testset.
-        test_performance, _ = self.task.evaluate_test(p.content)
+        test_performance, _, _ = self.task.evaluate_test(p.content)
         logger.info("Best prompt on test set: %s", test_performance)
         logger.info(
             "Usage (evolution model / evaluation model / total): %s / %s / %s",
@@ -181,6 +184,8 @@ class EvolutionAlgorithm(PromptOptimization):
 class GeneticAlgorithm(EvolutionAlgorithm):
     """The genetic algorithm implemented using LLMs."""
 
+    shorthand = "ga"
+
     @log_calls("Performing prompt evolution using GA")
     def evolve(
         self,
@@ -233,6 +238,8 @@ class GeneticAlgorithm(EvolutionAlgorithm):
 class DifferentialEvolution(EvolutionAlgorithm):
-    """The genetic algorithm implemented using LLMs."""
+    """The differential evolution algorithm implemented using LLMs."""
+
+    shorthand = "de"
 
-    @log_calls("Performing prompt evolution using GA")
+    @log_calls("Performing prompt evolution using DE")
     def evolve(
         self,
@@ -275,3 +282,19 @@ class DifferentialEvolution(EvolutionAlgorithm):
             )
         ]
         return population
+
+
+optimizers = {
+    algorithm.shorthand: algorithm for algorithm in EvolutionAlgorithm.__subclasses__()
+}
+
+
+def get_optimizer_class(name: str):
+    if name not in optimizers:
+        raise ValueError(f"Optimization algorithm {name!r} does not exist")
+    return optimizers[name]
+
+
+argument_parser.add_argument(
+    "--evolution-algorithm", "-a", type=str, choices=optimizers.keys(), default="ga"
+)
diff --git a/main.py b/main.py
index 4a139592384d88bce4fb0da3786aff2718a9143d..4ca5c5128953d3e98302bae41785f8e4c3774d46 100644
--- a/main.py
+++ b/main.py
@@ -1,11 +1,12 @@
 import os
 from typing import Any
 
-from cli import argument_parser
 from dotenv import load_dotenv
-from evolution import DifferentialEvolution, GeneticAlgorithm
-from models import Llama2, get_model_init
-from task import QuestionAnswering, SentimentAnalysis
+
+from cli import argument_parser
+from evolution import get_optimizer_class
+from models import Llama3, LLMModel
+from task import get_task
 from utils import logger
 
 load_dotenv()
@@ -25,31 +26,24 @@ if __name__ == "__main__":
     options = argument_parser.parse_args()
 
     # set up evolution model
-    model_init_fn = get_model_init(options.evolution_engine)
-    evolution_model = model_init_fn(
-        options.model,
-        chat=options.chat,
-    )
+    evolution_model = LLMModel.get_model(options.evolution_engine, options)
 
     match options.evolution_engine:
         case "llama2":
-            logger.info("Using Llama2 client as the evolution engine")
+            logger.info("Using Llama2 as the evolution engine")
         case "openai":
-            logger.info("Using OpenAI client as the evolution engine")
-
+            logger.info(f"Using {options.openai_model} as the evolution engine")
 
     # set up evaluation model
-    # NOTE currenty we always stick to Llama2 as evaluation engine
+    # NOTE currently we always stick to a Llama model as the evaluation engine
     # TODO allow to set separate engine and model for evaluation?
-    logger.info("Using Llama2 client as the evaluation engine")
+    logger.info("Using a Llama model as the evaluation engine")
 
+    evaluation_model: LLMModel
     match options.evolution_engine:
-        case "llama2":
+        case "llama2" | "llama3":
             evaluation_model = evolution_model
         case "openai":
-            evaluation_model = Llama2(
-                model_path=options.model,
-                chat=options.chat,
-            )
+            evaluation_model = Llama3(options)
 
     # log cli arguments
     logger.info(
@@ -67,40 +61,14 @@ if __name__ == "__main__":
             f"'{os.getenv('EP_DEBUG')}' is not allowed for env variable EP_DEBUG."
) - match options.task: - case "sa": - logger.info("Running with task sentiment analysis on dataset SetFit/sst2") - task = SentimentAnalysis( - evaluation_model, - "SetFit/sst2", - "SetFit/sst2", - use_grammar=options.use_grammar, - validation_split=f"validation[:{5 if debug else 200}]", - test_split="test[:20]" if debug else "test", - ) - case "qa": - logger.info("Running with task question answering on dataset squad") - task = QuestionAnswering( - evaluation_model, - "squad", - "squad", - use_grammar=options.use_grammar, - validation_split=f"train[:{5 if debug else 200}]", - test_split="validation[:20]" if debug else "validation", - ) - case _: - raise ValueError( - f"Task {options.task} does not exist. Choose from 'sa', 'qa'." - ) + task = get_task(options.task, evaluation_model, options) + logger.info( + f"Running with task {task.__class__.__name__} on dataset {task.validation_dataset.info.dataset_name}" + ) logger.info("Using evolutionary algorithm '%s'", options.evolution_algorithm) - # TODO allow to register algorithms and map to classes - if options.evolution_algorithm == "ga": - optimizer_class = GeneticAlgorithm - else: - optimizer_class = DifferentialEvolution - + optimizer_class = get_optimizer_class(options.evolution_algorithm) optimizer = optimizer_class( population_size=10, task=task, diff --git a/models.py b/models.py index f6f9fd6afa94004914f98e4db8f003df25f044a6..bf0d49cad99956612e3ca2a060b8b3bc562d2a54 100644 --- a/models.py +++ b/models.py @@ -1,47 +1,38 @@ -from abc import abstractmethod +import abc +import inspect +from abc import ABC, abstractmethod, abstractproperty +from argparse import ArgumentParser, Namespace from pathlib import Path -from typing import Any, Callable +from typing import Any, ClassVar import openai from llama_cpp import Llama + +from cli import argument_parser from opt_types import ModelUsage current_directory = Path(__file__).resolve().parent -MODELS = {} - - -def register_model(name: str): - def wrapper(f: Callable): - global MODELS - if name in MODELS: - raise ValueError("Cannot register model class %s: already exists", name) - MODELS[name] = f - - return f - - return wrapper - - -def get_models(): - return MODELS - - -def get_model_init(name: str): - if name not in MODELS: - raise ValueError("Model %s does not exist", name) - return MODELS[name] +class LLMModel(ABC): + models: ClassVar[dict[str, type["LLMModel"]]] = {} + chat: bool + def __init_subclass__(cls) -> None: + if inspect.isabstract(cls): + return + cls.models[cls.__name__.lower()] = cls + cls.register_arguments(argument_parser) -class LLMModel: - chat: bool - model: Any + @classmethod + def get_model(cls, name: str, options: Namespace): + if name not in cls.models: + raise ValueError("Model %s does not exist", name) + return cls.models[name](options) - def __init__(self, chat: bool, model: Any): + def __init__(self, options: Namespace): self.usage = ModelUsage() - self.chat = chat - self.model = model + self.chat = options.chat @abstractmethod def __call__( @@ -59,15 +50,17 @@ class LLMModel: ) -> Any: pass + @classmethod + @abstractmethod + def register_arguments(cls, parser: ArgumentParser): + pass + -@register_model("llama2") -class Llama2(LLMModel): - """Loads and queries a Llama2 model.""" +class LlamaModel(LLMModel): def __init__( self, - model_path: str, - chat: bool = False, + options: Namespace, n_gpu_layers: int = 60, n_threads: int = 8, n_ctx: int = 4096, @@ -76,16 +69,17 @@ class Llama2(LLMModel): ) -> None: # initialize model - model = Llama( - model_path, - 
chat_format="llama-2", + self.model = Llama( + options.llama_path, + chat_format=self.chat_format, verbose=verbose, n_gpu_layers=n_gpu_layers, n_threads=n_threads, n_ctx=n_ctx, **kwargs, ) - super().__init__(chat, model) + + super().__init__(options) def __call__( self, @@ -142,19 +136,42 @@ class Llama2(LLMModel): self.usage += usage return response_text, usage + @property + @abstractmethod + def chat_format(self) -> str: + pass + + @classmethod + def register_arguments(cls, parser: ArgumentParser): + group = parser.add_argument_group(f"{cls.__name__} model arguments") + group.add_argument( + "--llama-path", default="models/llama-2-13b-chat.Q5_K_M.gguf" + ) + + +class Llama2(LlamaModel): + @property + def chat_format(self) -> str: + return "llama-2" + + +class Llama3(LlamaModel): + @property + def chat_format(self) -> str: + return "llama-3" + -@register_model("openai") class OpenAI(LLMModel): """Queries an OpenAI model using its API.""" def __init__( self, - model: str = "gpt-3.5-turbo", - chat: bool = False, + options: Namespace, verbose: bool = False, **kwargs, ) -> None: - super().__init__(chat, model) + self.model_name = options.openai_model + super().__init__(options) # initialize client for API calls self.openai_client = openai.OpenAI(**kwargs) @@ -191,7 +208,7 @@ class OpenAI(LLMModel): }, ) response = self.openai_client.chat.completions.create( - model=self.model, + model=self.model_name, messages=messages, stop=stop, max_tokens=max_tokens, @@ -215,3 +232,19 @@ class OpenAI(LLMModel): usage = ModelUsage(**response.usage.__dict__) self.usage += usage return response.choices[0].text, usage + + @classmethod + def register_arguments(cls, parser: ArgumentParser): + group = parser.add_argument_group("OpenAI model arguments") + group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo") + + +argument_group = argument_parser.add_argument_group("Model arguments") +argument_group.add_argument( + "--evolution-engine", + "-e", + type=str, + choices=LLMModel.models.keys(), + default="llama2", +) +argument_group.add_argument("--chat", "-c", action="store_true") diff --git a/opt_types.py b/opt_types.py index 3d46de80719d2ff86025b92880b51e04044b113f..a80c850ff7032e393cfbfc73c072769f857a670d 100644 --- a/opt_types.py +++ b/opt_types.py @@ -30,6 +30,7 @@ class Prompt: content: str score: float usage: ModelUsage + evaluation_history: list[float] meta: dict = field(default_factory=dict) id: str = field(default_factory=lambda: uuid4().hex) diff --git a/optimization.py b/optimization.py index 417c12cbebac486ab0d94bf5c8027c3ac893bb72..f5592fcf1bdac49076af5e80730c02a309e35378 100644 --- a/optimization.py +++ b/optimization.py @@ -57,22 +57,36 @@ class PromptOptimization: def reset(self): self._init - def evaluate_prompt(self, prompt: str): - return self.task.evaluate_validation(prompt) + def evaluate_prompt(self, prompt: str, parents: tuple[Prompt] | None = None): + parent_histories = ( + [parent.evaluation_history for parent in parents] + if parents is not None + else None + ) + return self.task.evaluate_validation(prompt, parent_histories) def add_prompt( - self, prompt: str, parents: tuple[Prompt] = None, meta: dict = None + self, + prompt: str, + parents: tuple[Prompt] | None = None, + meta: dict | None = None, ) -> Prompt: - score, usage = self.evaluate_prompt(prompt) - prompt = Prompt(content=prompt, score=score, meta=meta, usage=usage) + score, usage, history = self.evaluate_prompt(prompt, parents) + prompt_object = Prompt( + content=prompt, + score=score, + meta=meta if meta is 
not None else {},
+            usage=usage,
+            evaluation_history=history,
+        )
 
         # keep track of prompt
-        self.all_prompts[prompt.id] = prompt
-        self.family_tree[prompt.id] = (
+        self.all_prompts[prompt_object.id] = prompt_object
+        self.family_tree[prompt_object.id] = (
             tuple(p.id for p in parents) if parents is not None else None
         )
 
-        return prompt
+        return prompt_object
 
     def add_prompts(
         self,
@@ -86,7 +100,7 @@ class PromptOptimization:
         ]
 
     def get_prompt(self, prompt_id: str):
-        return self.all_prompt[prompt_id]
+        return self.all_prompts[prompt_id]
 
     def get_prompts(self, prompt_ids: list[str]):
         return [self.get_prompt(p_id) for p_id in prompt_ids]
diff --git a/task.py b/task.py
index 461348d35be3b41dbb79986ca5df87e47640056b..d062253eeee5e38db3388d5879da7064d3afee4b 100644
--- a/task.py
+++ b/task.py
@@ -1,23 +1,121 @@
 import re
 from abc import abstractmethod
-from collections import defaultdict
+from argparse import Namespace
+from collections import deque
 from functools import lru_cache
-from typing import DefaultDict, Mapping, Union
+from statistics import mean
+from typing import Union
 
 from datasets import Dataset, load_dataset
 from evaluate import load as load_metric
 from llama_cpp import LlamaGrammar
-from models import Llama2, OpenAI
-from opt_types import ModelUsage
 from tqdm import tqdm
+
+from cli import argument_parser
+from models import Llama2, LLMModel, OpenAI
+from opt_types import ModelUsage
 from utils import log_calls, logger
 
 SYSTEM_MESSAGE = """
 You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 """
 
+DatasetDatum = dict
+
+
+class EarlyStoppingMonitor:
+
+    @abstractmethod
+    def update(self, score: float) -> bool:
+        raise NotImplementedError
+
+
+class MomentBasedStopping(EarlyStoppingMonitor):
+    """
+    Watch the first derivative (moment) of the metric to determine when to stop.
+ """ + + def __init__( + self, + *, + patience: int = 10, + start_after: int = 20, + min_moment_magnitude: float = 0.001, + ): + self.patience = patience + self.start_after = start_after + self.min_moment_magnitude = min_moment_magnitude + + self.moment_magnitudes = deque(maxlen=patience) + self.last_score = 0.0 + self.num_calls = 0 + + def update(self, score: float) -> bool: + # caclulate the current moment (dx/dt) + self.num_calls += 1 + if self.num_calls < self.start_after: + return False + + self.moment_magnitudes.append(abs(score - self.last_score)) + self.last_score = score + if len(self.moment_magnitudes) < self.patience: + return False + + if mean(self.moment_magnitudes) < self.min_moment_magnitude: + return True + + return False + + +class ParentBaselineBasedStopping(EarlyStoppingMonitor): + + def __init__( + self, + parent_histories: list[list[float]], + *, + patience: int = 10, + start_after: int = 20, + min_improvement: float = 0.001, + ): + self.parent_histories = parent_histories + self.patience = patience + self.start_after = start_after + self.min_improvement = min_improvement + self.num_calls = 0 + self.improvement_memory = deque(maxlen=patience) + + def update(self, score: float) -> bool: + self.num_calls += 1 + if self.num_calls < self.start_after: + return False + + parent_values = [ # get the metric value of the parents at the current step + ( + parent_history[self.num_calls - 1] + if len(parent_history) >= self.num_calls + else parent_history[-1] # extend with last value + ) + for parent_history in self.parent_histories + ] + self.improvement_memory.append( + score - max(parent_values) # compare with the best parent + ) + + if len(self.improvement_memory) < self.patience: + return False + + if max(self.improvement_memory) < self.min_improvement: + # if the highest improvement is less than the minimum improvement, we stop + return True + + return False + class Task: + shorthand: str + validation_dataset: Dataset + test_dataset: Dataset + def __init__( self, model: Union[Llama2, OpenAI], @@ -43,26 +141,76 @@ class Task: pass @abstractmethod - def _evaluate(self, prompt: str, dataset) -> tuple[float, ModelUsage]: + def _evaluate_sample( + self, prompt: str, datum: DatasetDatum + ) -> tuple[str, ModelUsage]: pass + @abstractmethod + def _aggregate_result(self, results: list) -> float: + pass + + def evaluate( + self, + prompt: str, + dataset: Dataset, + parent_histories: list[list[float]] | None = None, + ) -> tuple[float, ModelUsage, list[float]]: + + early_stopping: EarlyStoppingMonitor + early_stopping_params = { + "patience": max(len(dataset) // 20, 5), + "start_after": max(len(dataset) // 5, 5), + } + if parent_histories is not None: + early_stopping = ParentBaselineBasedStopping( + parent_histories, **early_stopping_params + ) + else: + early_stopping = MomentBasedStopping(**early_stopping_params) + + results: list = [] + dataset_iterator: tqdm[DatasetDatum] = tqdm( + dataset, desc="evaluating prompt", leave=False + ) + evaluation_usage = ModelUsage() + evaluation_history = [] + + for datum in dataset_iterator: + result, usage = self._evaluate_sample(prompt, datum) + results.append(result) + current_metric = self._aggregate_result(results) + dataset_iterator.set_postfix( + {self.metric_name: f"{current_metric*100:.1f}%"} + ) + evaluation_usage += usage + evaluation_history.append(current_metric) + if early_stopping.update(current_metric): + logger.info( + f"Early stopping after {len(results)} samples with {self.metric_name} of {current_metric*100:.1f}%" + ) + break 
+ + return self._aggregate_result(results), evaluation_usage, evaluation_history + @log_calls("Evaluating validation dataset") - @lru_cache(maxsize=None) - def evaluate_validation(self, prompt: str): - return self._evaluate(prompt, self.validation_dataset) + def evaluate_validation( + self, prompt: str, parent_histories: list[list[float]] | None = None + ): + return self.evaluate(prompt, self.validation_dataset, parent_histories) @log_calls("Evaluating test dataset") def evaluate_test(self, prompt: str): - return self._evaluate(prompt, self.test_dataset) + return self.evaluate(prompt, self.test_dataset) @property @abstractmethod - def metric_name(self): + def metric_name(self) -> str: pass @property @abstractmethod - def base_prompt(self): + def base_prompt(self) -> str: pass @@ -80,24 +228,16 @@ def sa_grammar_fn(verbose: bool = False): class SentimentAnalysis(Task): + shorthand = "sa" - def __init__( - self, - model, - validation_dataset: str, - test_dataset: str, - *, - use_grammar: bool, - validation_split: str | None = None, - test_split: str | None = None, - ) -> None: + def __init__(self, model, options: Namespace): super().__init__( model, - validation_dataset, - test_dataset, - use_grammar=use_grammar, - validation_split=validation_split, - test_split=test_split, + validation_dataset="SetFit/sst2", + test_dataset="SetFit/sst2", + use_grammar=options.use_grammar, + validation_split=f"validation[:{5 if options.debug else 200}]", + test_split="test[:20]" if options.debug else "test", ) def predict(self, prompt: str, text: str): @@ -117,39 +257,33 @@ class SentimentAnalysis(Task): return response, usage - def _evaluate(self, prompt: str, dataset: Dataset): + def _evaluate_sample(self, prompt: str, datum: DatasetDatum): sst2_labels = {"negative": 0, "positive": 1} - results: DefaultDict[str, int] = defaultdict(int) - dataset_iterator = tqdm(dataset, desc="evaluating prompt", leave=False) - evaluation_usage = ModelUsage() - - for datum in dataset_iterator: - response, usage = self.predict(prompt=prompt, text=datum["text"]) - response = response.lower() - evaluation_usage += usage - if self.use_grammar: - # model output is from label space - answer_label = sst2_labels[response] + response, usage = self.predict(prompt=prompt, text=datum["text"]) + response = response.lower() + if self.use_grammar: + # model output is from label space + answer_label = sst2_labels[response] + else: + answer_label = None + for label in sst2_labels.keys(): + if label in response: + answer_label = sst2_labels[label] + break else: - answer_label = None - for label in sst2_labels.keys(): - if label in response: - answer_label = sst2_labels[label] - break - else: - logger.warning(f"Invalid answer: {response}") - results["failed"] += 1 - continue - - classification_result = ( - "incorrect" if answer_label != datum["label"] else "correct" - ) - results[classification_result] += 1 - dataset_iterator.set_postfix(results) + logger.warning(f"Invalid answer: {response}") + return "failed", usage + + classification_result = ( + "incorrect" if answer_label != datum["label"] else "correct" + ) + return classification_result, usage - accuracy = results["correct"] / sum(results.values()) - return accuracy, evaluation_usage + def _aggregate_result(self, results: list[str]) -> float: + num_correct_results = sum(1 for result in results if result == "correct") + accuracy = num_correct_results / len(results) + return accuracy @property def metric_name(self): @@ -196,27 +330,19 @@ def qa_grammar_fn(context: str, verbose: bool = 
False):
 
 
 class QuestionAnswering(Task):
+    shorthand = "qa"
 
-    def __init__(
-        self,
-        model,
-        validation_dataset: str,
-        test_dataset: str,
-        *,
-        use_grammar: bool,
-        validation_split: str | None = None,
-        test_split: str | None = None,
-    ) -> None:
+    def __init__(self, model, options: Namespace):
         self.metric = load_metric("squad")
 
         super().__init__(
             model,
-            validation_dataset,
-            test_dataset,
-            use_grammar=use_grammar,
-            validation_split=validation_split,
-            test_split=test_split,
+            "squad",
+            "squad",
+            use_grammar=options.use_grammar,
+            validation_split=f"train[:{5 if options.debug else 200}]",
+            test_split="validation[:20]" if options.debug else "validation",
         )
 
     def predict(self, prompt: str, context: str, question: str):
@@ -255,10 +381,26 @@
 
         return response, usage
 
-    def _evaluate(self, prompt: str, dataset: Dataset):
-        evaluation_usage = ModelUsage()
+    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        answer, usage = self.predict(
+            prompt,
+            context=datum["context"],
+            question=datum["question"],
+        )
+        # TODO check if answer is lower-cased in metric computation
+
+        result = self.metric.compute(
+            predictions=[{"prediction_text": answer, "id": datum["id"]}],
+            references=[{"answers": datum["answers"], "id": datum["id"]}],
+        )
+
+        return result["f1"] / 100, usage
+
+    def _aggregate_result(self, results: list[float]) -> float:
+        return sum(results) / len(results)
 
-        def replace_symbol_for_grammar(sample: Mapping):
+    def evaluate(self, prompt: str, dataset: Dataset, parent_histories: list[list[float]] | None = None):
+        def replace_symbol_for_grammar(sample: DatasetDatum):
             symbol_replacement_mapping = {
                 "\u2013": "-",
                 "\u2014": "-",
@@ -281,35 +423,7 @@
         if self.use_grammar:
             # NOTE: the LlamaGrammar has issues with symbol '–' therefore we replace all occurences with '-' (hyphen)
             dataset = dataset.map(replace_symbol_for_grammar, desc="Replacing symbols")
-
-        dataset_iterator = tqdm(dataset, desc="evaluating prompt", leave=False)
-
-        num_samples = 0
-        f1 = 0.0
-        em = 0
-        for datum in dataset_iterator:
-            answer, usage = self.predict(
-                prompt,
-                context=datum["context"],
-                question=datum["question"],
-            )
-            # TODO check if answer is lower-cased in metric computation
-
-            evaluation_usage += usage
-
-            num_samples += 1
-            result = self.metric.compute(
-                predictions=[{"prediction_text": answer, "id": datum["id"]}],
-                references=[{"answers": datum["answers"], "id": datum["id"]}],
-            )
-            f1 += result["f1"]
-            em += result["exact_match"]
-
-            dataset_iterator.set_postfix(
-                {"f1": f1 / num_samples, "em": em / num_samples}
-            )
-
-        return f1 / num_samples, evaluation_usage
+        return super().evaluate(prompt, dataset, parent_histories)
 
     @property
     def metric_name(self):
@@ -319,3 +433,20 @@
     def base_prompt(self):
         # TODO find good prompt
         return """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. 
Make sure that the answer is taken directly from the context."""
+
+
+tasks = {task.shorthand: task for task in Task.__subclasses__()}
+
+
+def get_task(name: str, evaluation_model: LLMModel, options: Namespace):
+    if name not in tasks:
+        raise ValueError(f"Task {name!r} does not exist")
+    return tasks[name](evaluation_model, options)
+
+
+argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
+argument_group = argument_parser.add_argument_group("Task arguments")
+argument_group.add_argument("--use-grammar", "-g", action="store_true")
+argument_group.add_argument(
+    "--task", "-t", type=str, required=True, choices=tasks.keys()
+)
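
Reviewer note (not part of the patch): the diff replaces the hand-maintained MODELS/task/algorithm lookups with subclass registries, so new engines, tasks, and optimizers register themselves and contribute their own CLI arguments. The sketch below shows how a new engine would be expected to plug in under those assumptions; the Mistral class, the --mistral-path flag, and the model path are hypothetical and only illustrate the extension pattern.

# Hypothetical extension sketch (illustrative only, not part of this diff).
from argparse import ArgumentParser, Namespace
from typing import Any

from models import LLMModel


class Mistral(LLMModel):
    """Defining the subclass is enough: LLMModel.__init_subclass__ registers it
    under "mistral" and calls register_arguments() on the shared CLI parser."""

    def __init__(self, options: Namespace) -> None:
        super().__init__(options)
        # a real implementation would load its backend here, e.g. from options.mistral_path

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError("backend-specific generation goes here")

    @classmethod
    def register_arguments(cls, parser: ArgumentParser):
        group = parser.add_argument_group("Mistral model arguments")
        group.add_argument("--mistral-path", default="models/mistral.gguf")


# Provided this subclass is defined before argument_parser.parse_args() runs,
# "mistral" becomes a valid --evolution-engine choice (choices is a live dict view)
# and LLMModel.get_model("mistral", options) returns an instance of the class above.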