Skip to content
Snippets Groups Projects
Commit c954dded authored by Max Kimmich's avatar Max Kimmich
Browse files

Further refactor class Task and make sure that grammar is only built if used

parent 9f5e2581
No related branches found
No related tags found
1 merge request!7Refactor tasks and models and fix format for various models
...@@ -374,21 +374,17 @@ class Task(metaclass=ABCMeta): ...@@ -374,21 +374,17 @@ class Task(metaclass=ABCMeta):
prompt_with_examples = self.build_demonstration_prompt(self.demonstration_samples, prompt=prompt) prompt_with_examples = self.build_demonstration_prompt(self.demonstration_samples, prompt=prompt)
for datum in dataset_iterator: for datum in dataset_iterator:
# build prompt for current sample
prompt_for_datum = self.build_prompt_input(datum, prompt=prompt_with_examples, use_prediction_prefix=self.model._get_prediction_prefix() is None)
# input(f"Prompt for datum:\n{prompt_for_datum}")
# run prediction # run prediction
response, usage = self.predict(prompt=prompt_for_datum, grammar=self._get_grammar(datum)) response, usage = self.predict(prompt=prompt_with_examples, datum=datum)
# input(f"Response: '{response}'") logger.debug(f"Response: '{response}'")
# parse response # parse response
response = self._parse_response(response=response) response = self._parse_response(response=response)
# input(f"Parsed response: {response}") logger.debug(f"Parsed response: '{response}'")
# evaluate response # evaluate response
result = self._evaluate_sample(response=response, datum=datum) result = self._evaluate_sample(response=response, datum=datum)
# input( logger.debug(
# f"Prediction: {response}, Gold label: {self._get_gold_label_generation_for_datum(datum)}, Result: {result}" f"Prediction: '{response}', Gold label: '{self._get_gold_label_generation_for_datum(datum)}', Result: {result}"
# ) )
results.append(result) results.append(result)
current_metric = self._aggregate_result(results) current_metric = self._aggregate_result(results)
dataset_iterator.set_postfix({self.metric_name: f"{current_metric:.2f}"}) dataset_iterator.set_postfix({self.metric_name: f"{current_metric:.2f}"})
...@@ -405,14 +401,17 @@ class Task(metaclass=ABCMeta): ...@@ -405,14 +401,17 @@ class Task(metaclass=ABCMeta):
return self._aggregate_result(results), evaluation_usage, evaluation_history return self._aggregate_result(results), evaluation_usage, evaluation_history
@weave.op() @weave.op()
def predict(self, prompt: str, grammar: LlamaGrammar) -> tuple[str, ModelUsage]: def predict(self, prompt: str, datum: DatasetDatum) -> tuple[str, ModelUsage]:
# run model for inference using grammar to constrain output # run model for inference using grammar to constrain output
# TODO grammar also depends on prompt and vice-versa -> what are good labels? # TODO grammar also depends on prompt and vice-versa -> what are good labels?
# build prompt for current sample
prompt_for_datum = self.build_prompt_input(datum, prompt=prompt, use_prediction_prefix=self.model._get_prediction_prefix() is None)
logger.debug(f"Prompt for datum:\n{prompt_for_datum}")
response, _, _, usage = self.model.create_completion( response, _, _, usage = self.model.create_completion(
system_message=SYSTEM_MESSAGE, system_message=SYSTEM_MESSAGE,
prompt=prompt, prompt=prompt_for_datum,
# grammar can be applied to constrain the model output # grammar can be applied to constrain the model output
grammar=grammar if self.use_grammar else None, grammar=self._get_grammar(datum) if self.use_grammar else None,
# we use cached completions to speed up the process although we lose the non-deterministic behavior of LMs, but we're ok with a single result # we use cached completions to speed up the process although we lose the non-deterministic behavior of LMs, but we're ok with a single result
use_cache=True, use_cache=True,
# use less randomness, i.e., more certain outputs # use less randomness, i.e., more certain outputs
...@@ -451,8 +450,7 @@ class Task(metaclass=ABCMeta): ...@@ -451,8 +450,7 @@ class Task(metaclass=ABCMeta):
def _get_prediction_prefix() -> str: ... def _get_prediction_prefix() -> str: ...
@abstractmethod @abstractmethod
def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar: def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar: ...
pass
@abstractmethod @abstractmethod
def _evaluate_sample(self, response: str, datum: DatasetDatum) -> Any: ... def _evaluate_sample(self, response: str, datum: DatasetDatum) -> Any: ...
...@@ -462,19 +460,15 @@ class Task(metaclass=ABCMeta): ...@@ -462,19 +460,15 @@ class Task(metaclass=ABCMeta):
@abstractmethod @abstractmethod
# This method is needed for the demonstration examples. # This method is needed for the demonstration examples.
def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str: def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str: ...
pass
@abstractmethod @abstractmethod
def _aggregate_result(self, results: list) -> float: def _aggregate_result(self, results: list) -> float: ...
pass
@property @property
@abstractmethod @abstractmethod
def metric_name(self) -> str: def metric_name(self) -> str: ...
pass
@property @property
@abstractmethod @abstractmethod
def base_prompts(self) -> list[str]: def base_prompts(self) -> list[str]: ...
pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment