diff --git a/cli.py b/cli.py
index 3b1806313217075d690acade370d9a185fb12f2f..ab05d658a20c4392677707429d9265eed26cb2e0 100644
--- a/cli.py
+++ b/cli.py
@@ -12,5 +12,6 @@ argument_parser.add_argument("--model-path", "-m", type=str, required=True)
 argument_parser.add_argument(
     "--task", "-t", type=str, required=True, choices=["sa", "qa"]
 )
+argument_parser.add_argument("--use-grammar", "-g", action="store_true")
 argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
 argument_parser.add_argument("--chat", "-c", action="store_true")
diff --git a/main.py b/main.py
index 13e4280cd627f8b9afe2380d3db451476f174ec2..32633ae54813e95ed6abcb3aa4eafdacad37982e 100644
--- a/main.py
+++ b/main.py
@@ -318,6 +318,7 @@ if __name__ == "__main__":
                 evaluation_model,
                 "SetFit/sst2",
                 "SetFit/sst2",
+                use_grammar=options.use_grammar,
                 validation_split=f"validation[:{5 if debug else 200}]",
                 test_split="test[:20]" if debug else "test",
             )
@@ -327,6 +328,7 @@ if __name__ == "__main__":
                 evaluation_model,
                 "squad",
                 "squad",
+                use_grammar=options.use_grammar,
                 validation_split=f"train[:{5 if debug else 200}]",
                 test_split="validation[:20]" if debug else "validation",
             )
diff --git a/task.py b/task.py
index 1a1884854ad85dc48772a992a33eb901d8950c7a..97fee81ef835721eb68250ae2fda21f3ce0db21a 100644
--- a/task.py
+++ b/task.py
@@ -31,9 +31,10 @@ class Task:
         model: Union[Llama2, OpenAI],
         validation_dataset: str,
         test_dataset: str,
+        *,
+        use_grammar: bool,
         validation_split: str = None,
         test_split: str = None,
-        use_grammar: bool = True,
     ) -> None:
         self.model = model
 
@@ -86,22 +87,24 @@ def sa_grammar_fn(verbose: bool = False):
 
 
 class SentimentAnalysis(Task):
+
     def __init__(
         self,
         model,
         validation_dataset: str,
         test_dataset: str,
+        *,
+        use_grammar: bool,
         validation_split: str = None,
         test_split: str = None,
-        use_grammar: bool = True,
     ) -> None:
         super().__init__(
             model,
             validation_dataset,
             test_dataset,
-            validation_split,
-            test_split,
-            use_grammar,
+            use_grammar=use_grammar,
+            validation_split=validation_split,
+            test_split=test_split,
         )
 
     def predict(self, prompt: str, text: str):
@@ -196,14 +199,16 @@ def qa_grammar_fn(context: str, verbose: bool = False):
 
 
 class QuestionAnswering(Task):
+
     def __init__(
         self,
         model,
         validation_dataset: str,
         test_dataset: str,
+        *,
+        use_grammar: bool,
         validation_split: str = None,
         test_split: str = None,
-        use_grammar: bool = True,
     ) -> None:
         self.evaluation_prompt = """
         Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
@@ -226,9 +231,9 @@ class QuestionAnswering(Task):
             model,
             validation_dataset,
             test_dataset,
-            validation_split,
-            test_split,
-            use_grammar,
+            use_grammar=use_grammar,
+            validation_split=validation_split,
+            test_split=test_split,
         )
 
     def predict(self, prompt: str, context: str, question: str):
diff --git a/utils.py b/utils.py
index 88145912bf0131e152a65cc478e438f432e4b5a7..6dcd295524f8b11f1a377c3e69a13a8c01a4f091 100644
--- a/utils.py
+++ b/utils.py
@@ -146,6 +146,7 @@ def save_snapshot(
                     "validation_dataset": task.validation_dataset.info.dataset_name,
                     "test_dataset": task.test_dataset.info.dataset_name,
                     "metric": task.metric_name,
+                    "use_grammar": task.use_grammar,
                 },
                 "model": {"name": model.__class__.__name__},
                 "run_options": run_options,