From 4c276ebc6bfb42a12b0cdcd15f343574dced3a7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de> Date: Wed, 13 Mar 2024 15:27:58 +0100 Subject: [PATCH] add --use-grammar flag to disable grammars in tasks --- cli.py | 1 + main.py | 2 ++ task.py | 23 ++++++++++++++--------- utils.py | 1 + 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/cli.py b/cli.py index 3b18063..ab05d65 100644 --- a/cli.py +++ b/cli.py @@ -12,5 +12,6 @@ argument_parser.add_argument("--model-path", "-m", type=str, required=True) argument_parser.add_argument( "--task", "-t", type=str, required=True, choices=["sa", "qa"] ) +argument_parser.add_argument("--use-grammar", "-g", action="store_true") argument_parser.add_argument("--debug", "-d", action="store_true", default=None) argument_parser.add_argument("--chat", "-c", action="store_true") diff --git a/main.py b/main.py index 13e4280..32633ae 100644 --- a/main.py +++ b/main.py @@ -318,6 +318,7 @@ if __name__ == "__main__": evaluation_model, "SetFit/sst2", "SetFit/sst2", + use_grammar=options.use_grammar, validation_split=f"validation[:{5 if debug else 200}]", test_split="test[:20]" if debug else "test", ) @@ -327,6 +328,7 @@ if __name__ == "__main__": evaluation_model, "squad", "squad", + use_grammar=options.use_grammar, validation_split=f"train[:{5 if debug else 200}]", test_split="validation[:20]" if debug else "validation", ) diff --git a/task.py b/task.py index 1a18848..97fee81 100644 --- a/task.py +++ b/task.py @@ -31,9 +31,10 @@ class Task: model: Union[Llama2, OpenAI], validation_dataset: str, test_dataset: str, + *, + use_grammar: bool, validation_split: str = None, test_split: str = None, - use_grammar: bool = True, ) -> None: self.model = model @@ -86,22 +87,24 @@ def sa_grammar_fn(verbose: bool = False): class SentimentAnalysis(Task): + def __init__( self, model, validation_dataset: str, test_dataset: str, + *, + use_grammar: bool, validation_split: str = None, test_split: str 
= None, - use_grammar: bool = True, ) -> None: super().__init__( model, validation_dataset, test_dataset, - validation_split, - test_split, - use_grammar, + use_grammar=use_grammar, + validation_split=validation_split, + test_split=test_split, ) def predict(self, prompt: str, text: str): @@ -196,14 +199,16 @@ def qa_grammar_fn(context: str, verbose: bool = False): class QuestionAnswering(Task): + def __init__( self, model, validation_dataset: str, test_dataset: str, + *, + use_grammar: bool, validation_split: str = None, test_split: str = None, - use_grammar: bool = True, ) -> None: self.evaluation_prompt = """ Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. @@ -226,9 +231,9 @@ class QuestionAnswering(Task): model, validation_dataset, test_dataset, - validation_split, - test_split, - use_grammar, + use_grammar=use_grammar, + validation_split=validation_split, + test_split=test_split, ) def predict(self, prompt: str, context: str, question: str): diff --git a/utils.py b/utils.py index 8814591..6dcd295 100644 --- a/utils.py +++ b/utils.py @@ -146,6 +146,7 @@ def save_snapshot( "validation_dataset": task.validation_dataset.info.dataset_name, "test_dataset": task.test_dataset.info.dataset_name, "metric": task.metric_name, + "use_grammar": task.use_grammar, }, "model": {"name": model.__class__.__name__}, "run_options": run_options, -- GitLab