Refactor tasks so that tasks have more control over model input and output

Merged Max Kimmich requested to merge refactor-task into master
11 files changed: +144 −96
 import logging
 import re
 from abc import abstractmethod
-from functools import cache, lru_cache
+from functools import lru_cache
 from typing import Iterable
 
 from datasets import Dataset
@@ -55,10 +55,9 @@ class QuestionAnswering(Task):
         self.metric = load_metric("squad")
 
-    @cache
     def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar:
         # context-sensitive grammar
-        context = self._get_context_from_datum(datum)
+        context = self._get_context_for_datum(datum)
         try:
             return extractive_qa_grammar_fn(context)
         except Exception as e:
@@ -69,12 +68,14 @@ class QuestionAnswering(Task):
exc_info=e,
)
@staticmethod
def _get_generation_prefix():
return "Answer: "
def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
context = self._get_context_from_datum(datum)
question = self._get_question_from_datum(datum)
return (
"\nContext: " + '"' + context + '"' + "\nQuestion: " + '"' + question + '"'
)
context = self._get_context_for_datum(datum)
question = self._get_question_for_datum(datum)
return f"Context: {context}\nQuestion: {question}"
def _get_demonstration_sample_ids(
self, dataset: Dataset, n_evaluation_demo: int
@@ -87,39 +88,41 @@ class QuestionAnswering(Task):
         pass
 
     @abstractmethod
-    def _get_context_from_datum(self, datum: DatasetDatum):
+    def _get_context_for_datum(self, datum: DatasetDatum):
         pass
 
     @abstractmethod
-    def _get_question_from_datum(self, datum: DatasetDatum):
+    def _get_question_for_datum(self, datum: DatasetDatum):
         pass
 
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+    def _parse_response(self, response: str) -> str:
+        if self.use_grammar:
+            return response
+        # if we do not use a grammar, we need to extract the answer from the response;
+        # otherwise the answer comes from the context, as enforced by the grammar
+        prefix_to_match = self._get_generation_prefix().replace(" ", r"\s?")
+        matches = re.findall(
+            # regex that matches the answer after the generation prefix
+            rf"(?:{prefix_to_match})?(.+)",
+            response.splitlines()[-1],
+            flags=re.IGNORECASE,
+        )
+        # look for an answer in the response; if none is found, use the whole response
+        if matches:
+            return matches[-1]
+        else:
+            return response
+
+    def _evaluate_sample(self, response: str, datum: DatasetDatum) -> float:
         _id = self._get_id_from_datum(datum)
         gold_answers = self._get_gold_label_for_datum(datum)
-        response, usage = self.predict(prompt, datum)
-        response = response.lower()
-        if not self.use_grammar:
-            # if we do not use a grammar, we need to extract the answer from the response
-            # otherwise the answer is from the context as enforced by the grammar
-            matches = re.findall(
-                # regex that matches class labels after "Response: "
-                rf"(?:Response:\s?)?(.+)",
-                response.splitlines()[-1],
-                flags=re.IGNORECASE,
-            )
-            # look for an answer in the response, if not found, use whole response
-            if matches:
-                answer = matches[-1]
-            else:
-                answer = response
         result = self.metric.compute(
-            predictions=[{"prediction_text": answer, "id": _id}],
+            predictions=[{"prediction_text": response, "id": _id}],
             references=[{"answers": gold_answers, "id": _id}],
         )
-        return result["f1"] / 100, usage
+        return result["f1"]
 
     def _aggregate_result(self, results: list[float]) -> float:
         return sum(results) / len(results)
@@ -179,10 +182,10 @@ class SQuAD(QuestionAnswering):
     def load_test_set(self, test_dataset: str, test_split: str | None):
        return super().load_test_set("squad", "validation")
 
-    def _get_context_from_datum(self, datum: DatasetDatum):
+    def _get_context_for_datum(self, datum: DatasetDatum):
         return datum["context"]
 
-    def _get_question_from_datum(self, datum: DatasetDatum):
+    def _get_question_for_datum(self, datum: DatasetDatum):
         return datum["question"]
 
     def _get_id_from_datum(self, datum: DatasetDatum):
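
The shape of the refactor: _evaluate_sample no longer calls self.predict itself. The caller (not shown in this diff) is now expected to run the model, pass the response through _parse_response, and hand the parsed answer to _evaluate_sample, with the generation prefix made overridable via _get_generation_prefix. A standalone sketch of the new parsing behavior, illustrative only: parse_response and GENERATION_PREFIX are stand-ins for the task methods, not names from this MR.

import re

# stand-in for Task._get_generation_prefix() (hypothetical module-level constant)
GENERATION_PREFIX = "Answer: "

def parse_response(response: str, use_grammar: bool = False) -> str:
    # with a grammar, generation is already constrained to a span of the
    # context, so the raw response is the answer
    if use_grammar:
        return response
    # otherwise, match the last line, with an optional "Answer:" prefix
    # (the space is relaxed to \s? so "Answer:Paris" also matches)
    prefix_to_match = GENERATION_PREFIX.replace(" ", r"\s?")
    matches = re.findall(
        rf"(?:{prefix_to_match})?(.+)",
        response.splitlines()[-1],
        flags=re.IGNORECASE,
    )
    # fall back to the whole response if nothing matches
    return matches[-1] if matches else response

print(parse_response("Let me think.\nAnswer: Paris"))  # -> "Paris"

Note also that _evaluate_sample now returns the raw SQuAD F1 from metric.compute (on a 0-100 scale) instead of result["f1"] / 100, and no longer returns token usage; presumably usage accounting moves to whatever now owns the predict call.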