Commit c954dded authored by Max Kimmich

Further refactor class Task and make sure that grammar is only built if used

parent 9f5e2581
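
In short, the change moves per-sample prompt construction and grammar lookup out of the evaluation loop and into predict(), so _get_grammar() is only invoked when use_grammar is set. A minimal sketch of the resulting control flow (stand-in types and a stubbed model call, not the project's real classes; the actual signatures are in the diff below):

```python
from typing import Any, Optional


class TaskSketch:
    """Illustrative stand-in for the refactored Task class."""

    use_grammar: bool = False

    def predict(self, prompt: str, datum: dict) -> tuple[str, dict]:
        # prompt construction now happens inside predict()
        prompt_for_datum = f"{prompt}\n{datum['input']}"
        # the grammar is only built if it is actually used
        grammar = self._get_grammar(datum) if self.use_grammar else None
        response = self._create_completion(prompt_for_datum, grammar)
        return response, {"tokens": 0}  # stand-in for ModelUsage

    def _get_grammar(self, datum: dict) -> Any:
        raise NotImplementedError  # subclass responsibility, as in the real Task

    def _create_completion(self, prompt: str, grammar: Optional[Any]) -> str:
        return "stubbed response"  # stand-in for self.model.create_completion(...)
```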
@@ -374,21 +374,17 @@ class Task(metaclass=ABCMeta):
         prompt_with_examples = self.build_demonstration_prompt(self.demonstration_samples, prompt=prompt)
         for datum in dataset_iterator:
-            # build prompt for current sample
-            prompt_for_datum = self.build_prompt_input(datum, prompt=prompt_with_examples, use_prediction_prefix=self.model._get_prediction_prefix() is None)
-            # input(f"Prompt for datum:\n{prompt_for_datum}")
             # run prediction
-            response, usage = self.predict(prompt=prompt_for_datum, grammar=self._get_grammar(datum))
-            # input(f"Response: '{response}'")
+            response, usage = self.predict(prompt=prompt_with_examples, datum=datum)
+            logger.debug(f"Response: '{response}'")
             # parse response
             response = self._parse_response(response=response)
-            # input(f"Parsed response: {response}")
+            logger.debug(f"Parsed response: '{response}'")
             # evaluate response
             result = self._evaluate_sample(response=response, datum=datum)
-            # input(
-            #     f"Prediction: {response}, Gold label: {self._get_gold_label_generation_for_datum(datum)}, Result: {result}"
-            # )
+            logger.debug(
+                f"Prediction: '{response}', Gold label: '{self._get_gold_label_generation_for_datum(datum)}', Result: {result}"
+            )
             results.append(result)
             current_metric = self._aggregate_result(results)
             dataset_iterator.set_postfix({self.metric_name: f"{current_metric:.2f}"})
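
The loop above reports the running metric through the progress bar after every sample. A self-contained sketch of that pattern (tqdm assumed; the data and the accuracy metric are illustrative):

```python
from tqdm import tqdm

samples = [True, False, True]  # stand-ins for per-sample evaluation results
results: list[bool] = []
dataset_iterator = tqdm(samples, desc="evaluating")
for datum in dataset_iterator:
    results.append(datum)  # stand-in for _evaluate_sample(...)
    current_metric = sum(results) / len(results)  # stand-in for _aggregate_result(...)
    dataset_iterator.set_postfix({"accuracy": f"{current_metric:.2f}"})
```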
@@ -405,14 +401,17 @@ class Task(metaclass=ABCMeta):
         return self._aggregate_result(results), evaluation_usage, evaluation_history
 
     @weave.op()
-    def predict(self, prompt: str, grammar: LlamaGrammar) -> tuple[str, ModelUsage]:
+    def predict(self, prompt: str, datum: DatasetDatum) -> tuple[str, ModelUsage]:
         # run model for inference using grammar to constrain output
         # TODO grammar also depends on prompt and vice-versa -> what are good labels?
+        # build prompt for current sample
+        prompt_for_datum = self.build_prompt_input(datum, prompt=prompt, use_prediction_prefix=self.model._get_prediction_prefix() is None)
+        logger.debug(f"Prompt for datum:\n{prompt_for_datum}")
         response, _, _, usage = self.model.create_completion(
             system_message=SYSTEM_MESSAGE,
-            prompt=prompt,
+            prompt=prompt_for_datum,
             # grammar can be applied to constrain the model output
-            grammar=grammar if self.use_grammar else None,
+            grammar=self._get_grammar(datum) if self.use_grammar else None,
             # we use cached completions to speed up the process; although we lose the non-deterministic behavior of LMs, we're ok with a single result
             use_cache=True,
             # use less randomness, i.e., more certain outputs
@@ -451,8 +450,7 @@ class Task(metaclass=ABCMeta):
     def _get_prediction_prefix() -> str: ...
 
     @abstractmethod
-    def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar:
-        pass
+    def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar: ...
 
     @abstractmethod
     def _evaluate_sample(self, response: str, datum: DatasetDatum) -> Any: ...
@@ -462,19 +460,15 @@ class Task(metaclass=ABCMeta):
     @abstractmethod
     # This method is needed for the demonstration examples.
-    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
-        pass
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str: ...
 
     @abstractmethod
-    def _aggregate_result(self, results: list) -> float:
-        pass
+    def _aggregate_result(self, results: list) -> float: ...
 
     @property
     @abstractmethod
-    def metric_name(self) -> str:
-        pass
+    def metric_name(self) -> str: ...
 
     @property
     @abstractmethod
-    def base_prompts(self) -> list[str]:
-        pass
+    def base_prompts(self) -> list[str]: ...
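
With the stub bodies collapsed to `...`, a concrete task still has to provide every abstract member. A hypothetical subclass covering just the members visible in these hunks (the dataset fields, the grammar string, and the metric are illustrative; the real Task may require additional overrides not shown in this diff):

```python
from llama_cpp import LlamaGrammar  # assumed import for the LlamaGrammar type used in the diff


class SentimentTask(Task):  # hypothetical subclass, for illustration only
    @staticmethod
    def _get_prediction_prefix() -> str:
        return "Sentiment: "

    def _get_grammar(self, datum) -> LlamaGrammar:
        # built lazily by predict(), and only when self.use_grammar is True
        return LlamaGrammar.from_string('root ::= "positive" | "negative"')

    def _evaluate_sample(self, response: str, datum) -> bool:
        gold = self._get_gold_label_generation_for_datum(datum)
        return response.strip().lower() == gold.lower()

    def _get_gold_label_generation_for_datum(self, datum) -> str:
        return datum["label"]  # illustrative dataset field

    def _aggregate_result(self, results: list) -> float:
        return sum(results) / len(results)

    @property
    def metric_name(self) -> str:
        return "accuracy"

    @property
    def base_prompts(self) -> list[str]:
        return ["Classify the sentiment of the given text."]
```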