From c954ddedc79cbc6186dc87c7bfbc32b79c80cf59 Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Tue, 1 Oct 2024 16:16:23 +0200
Subject: [PATCH] Further refactor class Task and make sure that grammar is
 only built if used

---
 evoprompt/task/task.py | 40 +++++++++++++++++-----------------------
 1 file changed, 17 insertions(+), 23 deletions(-)

diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index 85f7e72..cfac945 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -374,21 +374,17 @@ class Task(metaclass=ABCMeta):
             prompt_with_examples = self.build_demonstration_prompt(self.demonstration_samples, prompt=prompt)
 
         for datum in dataset_iterator:
-            # build prompt for current sample
-
-            prompt_for_datum = self.build_prompt_input(datum, prompt=prompt_with_examples, use_prediction_prefix=self.model._get_prediction_prefix() is None)
-            # input(f"Prompt for datum:\n{prompt_for_datum}")
             # run prediction
-            response, usage = self.predict(prompt=prompt_for_datum, grammar=self._get_grammar(datum))
-            # input(f"Response: '{response}'")
+            response, usage = self.predict(prompt=prompt_with_examples, datum=datum)
+            logger.debug(f"Response: '{response}'")
             # parse response
             response = self._parse_response(response=response)
-            # input(f"Parsed response: {response}")
+            logger.debug(f"Parsed response: '{response}'")
             # evaluate response
             result = self._evaluate_sample(response=response, datum=datum)
-            # input(
-            #     f"Prediction: {response}, Gold label: {self._get_gold_label_generation_for_datum(datum)}, Result: {result}"
-            # )
+            logger.debug(
+                f"Prediction: '{response}', Gold label: '{self._get_gold_label_generation_for_datum(datum)}', Result: {result}"
+            )
             results.append(result)
             current_metric = self._aggregate_result(results)
             dataset_iterator.set_postfix({self.metric_name: f"{current_metric:.2f}"})
@@ -405,14 +401,17 @@ class Task(metaclass=ABCMeta):
         return self._aggregate_result(results), evaluation_usage, evaluation_history
 
     @weave.op()
-    def predict(self, prompt: str, grammar: LlamaGrammar) -> tuple[str, ModelUsage]:
+    def predict(self, prompt: str, datum: DatasetDatum) -> tuple[str, ModelUsage]:
         # run model for inference using grammar to constrain output
         # TODO grammar also depends on prompt and vice-versa -> what are good labels?
+        # build prompt for current sample
+        prompt_for_datum = self.build_prompt_input(datum, prompt=prompt, use_prediction_prefix=self.model._get_prediction_prefix() is None)
+        logger.debug(f"Prompt for datum:\n{prompt_for_datum}")
         response, _, _, usage = self.model.create_completion(
             system_message=SYSTEM_MESSAGE,
-            prompt=prompt,
+            prompt=prompt_for_datum,
             # grammar can be applied to constrain the model output
-            grammar=grammar if self.use_grammar else None,
+            grammar=self._get_grammar(datum) if self.use_grammar else None,
             # we use cached completions to speed up the process although we loose the non-deterministic behavior of LMs, but we're ok with a single result
             use_cache=True,
             # use less randomness, i.e., more certain outputs
@@ -451,8 +450,7 @@ class Task(metaclass=ABCMeta):
     def _get_prediction_prefix() -> str: ...
 
     @abstractmethod
-    def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar:
-        pass
+    def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar: ...
 
     @abstractmethod
     def _evaluate_sample(self, response: str, datum: DatasetDatum) -> Any: ...
@@ -462,19 +460,15 @@ class Task(metaclass=ABCMeta):
 
     @abstractmethod
     # This method is needed for the demonstration examples.
-    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
-        pass
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str: ...
 
     @abstractmethod
-    def _aggregate_result(self, results: list) -> float:
-        pass
+    def _aggregate_result(self, results: list) -> float: ...
 
     @property
     @abstractmethod
-    def metric_name(self) -> str:
-        pass
+    def metric_name(self) -> str: ...
 
     @property
     @abstractmethod
-    def base_prompts(self) -> list[str]:
-        pass
+    def base_prompts(self) -> list[str]: ...
--
GitLab
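
Note (illustrative sketch, not part of the patch): the practical effect of the refactor is that callers now hand `predict` the raw datum (`task.predict(prompt=prompt_with_examples, datum=datum)`), and `_get_grammar(datum)` is evaluated inside `predict` only when `self.use_grammar` is set, so grammar construction is skipped entirely for unconstrained runs. The subclass below shows what that contract might look like; the class name, label grammar, metric, and import paths are assumptions, only the abstract-method signatures come from the diff above, and the remaining abstract members (`_get_prediction_prefix`, `_get_gold_label_generation_for_datum`, ...) are omitted for brevity.

    from llama_cpp import LlamaGrammar

    from evoprompt.task.task import Task, DatasetDatum  # import path assumed

    class SentimentTask(Task):  # hypothetical subclass for illustration
        def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar:
            # After this patch, predict() calls _get_grammar only when
            # self.use_grammar is True, so this (potentially expensive)
            # construction no longer runs for grammar-free evaluations.
            return LlamaGrammar.from_string('root ::= "positive" | "negative"')

        def _evaluate_sample(self, response: str, datum: DatasetDatum) -> bool:
            # compare the parsed model output against the gold generation
            return response == self._get_gold_label_generation_for_datum(datum)

        def _aggregate_result(self, results: list) -> float:
            # fraction of correctly classified samples
            return sum(results) / len(results)

        @property
        def metric_name(self) -> str:
            return "accuracy"

        @property
        def base_prompts(self) -> list[str]:
            return ["Classify the sentiment of the given text."]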