From c954ddedc79cbc6186dc87c7bfbc32b79c80cf59 Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Tue, 1 Oct 2024 16:16:23 +0200
Subject: [PATCH] Further refactor class Task and make sure that grammar is
 only built if used

---
 evoprompt/task/task.py | 40 +++++++++++++++++-----------------------
 1 file changed, 17 insertions(+), 23 deletions(-)

diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index 85f7e72..cfac945 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -374,21 +374,17 @@ class Task(metaclass=ABCMeta):
             prompt_with_examples = self.build_demonstration_prompt(self.demonstration_samples, prompt=prompt)
 
         for datum in dataset_iterator:
-            # build prompt for current sample
-
-            prompt_for_datum = self.build_prompt_input(datum, prompt=prompt_with_examples, use_prediction_prefix=self.model._get_prediction_prefix() is None)
-            # input(f"Prompt for datum:\n{prompt_for_datum}")
             # run prediction
-            response, usage = self.predict(prompt=prompt_for_datum, grammar=self._get_grammar(datum))
-            # input(f"Response: '{response}'")
+            response, usage = self.predict(prompt=prompt_with_examples, datum=datum)
+            logger.debug(f"Response: '{response}'")
             # parse response
             response = self._parse_response(response=response)
-            # input(f"Parsed response: {response}")
+            logger.debug(f"Parsed response: '{response}'")
             # evaluate response
             result = self._evaluate_sample(response=response, datum=datum)
-            # input(
-            #     f"Prediction: {response}, Gold label: {self._get_gold_label_generation_for_datum(datum)}, Result: {result}"
-            # )
+            logger.debug(
+                f"Prediction: '{response}', Gold label: '{self._get_gold_label_generation_for_datum(datum)}', Result: {result}"
+            )
             results.append(result)
             current_metric = self._aggregate_result(results)
             dataset_iterator.set_postfix({self.metric_name: f"{current_metric:.2f}"})
@@ -405,14 +401,17 @@ class Task(metaclass=ABCMeta):
         return self._aggregate_result(results), evaluation_usage, evaluation_history
 
     @weave.op()
-    def predict(self, prompt: str, grammar: LlamaGrammar) -> tuple[str, ModelUsage]:
+    def predict(self, prompt: str, datum: DatasetDatum) -> tuple[str, ModelUsage]:
         # run model for inference using grammar to constrain output
         # TODO grammar also depends on prompt and vice-versa -> what are good labels?
+        # build prompt for current sample
+        prompt_for_datum = self.build_prompt_input(datum, prompt=prompt, use_prediction_prefix=self.model._get_prediction_prefix() is None)
+        logger.debug(f"Prompt for datum:\n{prompt_for_datum}")
         response, _, _, usage = self.model.create_completion(
             system_message=SYSTEM_MESSAGE,
-            prompt=prompt,
+            prompt=prompt_for_datum,
             # grammar can be applied to constrain the model output
-            grammar=grammar if self.use_grammar else None,
+            grammar=self._get_grammar(datum) if self.use_grammar else None,
             # we use cached completions to speed up the process although we loose the non-deterministic behavior of LMs, but we're ok with a single result
             use_cache=True,
             # use less randomness, i.e., more certain outputs
@@ -451,8 +450,7 @@ class Task(metaclass=ABCMeta):
     def _get_prediction_prefix() -> str: ...
 
     @abstractmethod
-    def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar:
-        pass
+    def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar: ...
 
     @abstractmethod
     def _evaluate_sample(self, response: str, datum: DatasetDatum) -> Any: ...
@@ -462,19 +460,15 @@ class Task(metaclass=ABCMeta):
 
     @abstractmethod
     # This method is needed for the demonstration examples.
-    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
-        pass
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str: ...
 
     @abstractmethod
-    def _aggregate_result(self, results: list) -> float:
-        pass
+    def _aggregate_result(self, results: list) -> float: ...
 
     @property
     @abstractmethod
-    def metric_name(self) -> str:
-        pass
+    def metric_name(self) -> str: ...
 
     @property
     @abstractmethod
-    def base_prompts(self) -> list[str]:
-        pass
+    def base_prompts(self) -> list[str]: ...
--
GitLab
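
Note (illustrative sketch, not part of the patch): the practical effect of the refactor is that callers now hand `predict` the raw datum (`task.predict(prompt=prompt_with_examples, datum=datum)`), and `_get_grammar(datum)` is evaluated inside `predict` only when `self.use_grammar` is set, so grammar construction is skipped entirely for unconstrained runs. The subclass below shows what that contract might look like; the class name, label grammar, metric, and import paths are assumptions, only the abstract-method signatures come from the diff above, and the remaining abstract members (`_get_prediction_prefix`, `_get_gold_label_generation_for_datum`, ...) are omitted for brevity.

    from llama_cpp import LlamaGrammar

    from evoprompt.task.task import Task, DatasetDatum  # import path assumed

    class SentimentTask(Task):  # hypothetical subclass for illustration
        def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar:
            # After this patch, predict() calls _get_grammar only when
            # self.use_grammar is True, so this (potentially expensive)
            # construction no longer runs for grammar-free evaluations.
            return LlamaGrammar.from_string('root ::= "positive" | "negative"')

        def _evaluate_sample(self, response: str, datum: DatasetDatum) -> bool:
            # compare the parsed model output against the gold generation
            return response == self._get_gold_label_generation_for_datum(datum)

        def _aggregate_result(self, results: list) -> float:
            # fraction of correctly classified samples
            return sum(results) / len(results)

        @property
        def metric_name(self) -> str:
            return "accuracy"

        @property
        def base_prompts(self) -> list[str]:
            return ["Classify the sentiment of the given text."]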