diff --git a/cli.py b/cli.py
index 3b1806313217075d690acade370d9a185fb12f2f..ab05d658a20c4392677707429d9265eed26cb2e0 100644
--- a/cli.py
+++ b/cli.py
@@ -12,5 +12,6 @@ argument_parser.add_argument("--model-path", "-m", type=str, required=True)
 argument_parser.add_argument(
     "--task", "-t", type=str, required=True, choices=["sa", "qa"]
 )
+argument_parser.add_argument("--use-grammar", "-g", action="store_true")
 argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
 argument_parser.add_argument("--chat", "-c", action="store_true")
diff --git a/main.py b/main.py
index 13e4280cd627f8b9afe2380d3db451476f174ec2..32633ae54813e95ed6abcb3aa4eafdacad37982e 100644
--- a/main.py
+++ b/main.py
@@ -318,6 +318,7 @@ if __name__ == "__main__":
             evaluation_model,
             "SetFit/sst2",
             "SetFit/sst2",
+            use_grammar=options.use_grammar,
             validation_split=f"validation[:{5 if debug else 200}]",
             test_split="test[:20]" if debug else "test",
         )
@@ -327,6 +328,7 @@ if __name__ == "__main__":
             evaluation_model,
             "squad",
             "squad",
+            use_grammar=options.use_grammar,
             validation_split=f"train[:{5 if debug else 200}]",
             test_split="validation[:20]" if debug else "validation",
         )
diff --git a/task.py b/task.py
index 1a1884854ad85dc48772a992a33eb901d8950c7a..97fee81ef835721eb68250ae2fda21f3ce0db21a 100644
--- a/task.py
+++ b/task.py
@@ -31,9 +31,10 @@ class Task:
         model: Union[Llama2, OpenAI],
         validation_dataset: str,
         test_dataset: str,
+        *,
+        use_grammar: bool,
         validation_split: str = None,
         test_split: str = None,
-        use_grammar: bool = True,
     ) -> None:

         self.model = model
@@ -86,22 +87,24 @@ def sa_grammar_fn(verbose: bool = False):


 class SentimentAnalysis(Task):
+
     def __init__(
         self,
         model,
         validation_dataset: str,
         test_dataset: str,
+        *,
+        use_grammar: bool,
         validation_split: str = None,
         test_split: str = None,
-        use_grammar: bool = True,
     ) -> None:
         super().__init__(
             model,
             validation_dataset,
             test_dataset,
-            validation_split,
-            test_split,
-            use_grammar,
+            use_grammar=use_grammar,
+            validation_split=validation_split,
+            test_split=test_split,
         )

     def predict(self, prompt: str, text: str):
@@ -196,14 +199,16 @@ def qa_grammar_fn(context: str, verbose: bool = False):


 class QuestionAnswering(Task):
+
     def __init__(
         self,
         model,
         validation_dataset: str,
         test_dataset: str,
+        *,
+        use_grammar: bool,
         validation_split: str = None,
         test_split: str = None,
-        use_grammar: bool = True,
     ) -> None:
         self.evaluation_prompt = """
 Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
@@ -226,9 +231,9 @@ class QuestionAnswering(Task):
             model,
             validation_dataset,
             test_dataset,
-            validation_split,
-            test_split,
-            use_grammar,
+            use_grammar=use_grammar,
+            validation_split=validation_split,
+            test_split=test_split,
         )

     def predict(self, prompt: str, context: str, question: str):
diff --git a/utils.py b/utils.py
index 88145912bf0131e152a65cc478e438f432e4b5a7..6dcd295524f8b11f1a377c3e69a13a8c01a4f091 100644
--- a/utils.py
+++ b/utils.py
@@ -146,6 +146,7 @@ def save_snapshot(
             "validation_dataset": task.validation_dataset.info.dataset_name,
             "test_dataset": task.test_dataset.info.dataset_name,
             "metric": task.metric_name,
+            "use_grammar": task.use_grammar,
         },
         "model": {"name": model.__class__.__name__},
         "run_options": run_options,
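
Note on the signature change above: the bare `*` makes `use_grammar` keyword-only and removes its `True` default, so any call site still passing it positionally fails at call time instead of silently binding the value to `validation_split`. A minimal sketch of the mechanics (the `make_task` function below is illustrative only, not part of the diff):

```python
# Keyword-only boundary, mirroring the bare `*` introduced in Task.__init__.
def make_task(model, validation_dataset, test_dataset, *, use_grammar,
              validation_split=None, test_split=None):
    return (model, validation_dataset, test_dataset, use_grammar)

make_task("m", "squad", "squad", use_grammar=True)  # OK: passed by keyword
make_task("m", "squad", "squad", True)  # TypeError: takes 3 positional arguments but 4 were given
```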