Commit c27d342b authored by Max Kimmich

Refactor tasks

parent 2f2fd458
@@ -10,7 +10,7 @@ from evoprompt.models import LLMModel
from evoprompt.opt_types import ModelUsage, Prompt
from evoprompt.optimization import PromptOptimization
from evoprompt.task import Task
from evoprompt.utils import log_calls
from evoprompt.utils import log_calls, get_all_subclasses
logger = logging.getLogger(__name__)
from argparse import Namespace
from evoprompt.cli import argument_parser
from evoprompt.models import LLMModel
from evoprompt.task.question_answering import QuestionAnswering
from evoprompt.task.sentiment_analysis import SentimentAnalysis
# make sure to run definitions of subclasses of Task first
from evoprompt.task.task import Task
from evoprompt.utils import get_all_subclasses
# at this point, we assume that all subclasses of Task have been defined
tasks = {
task.shorthand: task
for task in get_all_subclasses(Task)
if hasattr(task, "shorthand")
}
def get_task(name: str, evaluation_model: LLMModel, **options):
if name not in tasks:
raise ValueError(f"Task {name} does not exist")
return tasks[name](evaluation_model, **options)
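# Illustrative usage (hypothetical flag values): tasks are looked up by their shorthand and
# constructed with the parsed CLI options, e.g.
#   task = get_task("squad", evaluation_model, use_grammar=False, no_early_stopping=False)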
argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
argument_group = argument_parser.add_argument_group("Task arguments")
argument_group.add_argument("--use-grammar", "-g", action="store_true")
argument_group.add_argument("--no-early-stopping", action="store_true")
argument_group.add_argument("--evaluate-shortest-first", action="store_true")
argument_group.add_argument(
"--task", "-t", type=str, required=True, choices=tasks.keys()
)
import logging
import re
from argparse import Namespace
from functools import lru_cache
from datasets import Dataset
from evaluate import load as load_metric
from llama_cpp import LlamaGrammar
from evoprompt.task.task import SYSTEM_MESSAGE, DatasetDatum, Task
logger = logging.getLogger(__name__)
def grammar_continuous_with_arbitrary_end(sequence: list[str], quote: str = '"'):
if len(sequence) == 0:
return ""
if quote is None:
quote = ""
escape_item = lambda text: text.replace('"', '\\"')
grammar = quote + escape_item(sequence[0]) + quote
for idx, item in enumerate(sequence[1:]):
new_item = "(" + quote + escape_item(item) + quote + ")?"
if idx == 0:
grammar += new_item
else:
grammar = grammar[: -2 * idx] + new_item + grammar[-2 * idx :]
return grammar
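# Worked example (illustrative): the nested optional groups match any contiguous
# prefix of the token sequence, e.g.
#   grammar_continuous_with_arbitrary_end(["The", " ", "cat"])
#   -> '"The"(" "("cat")?)?'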
# a grammar in GBNF notation which takes the context into account
# NOTE: for some reason the program crashes if the grammar is an instance attribute of the task class
@lru_cache
def extractive_qa_grammar_fn(context: str, verbose: bool = False):
grammar_str = "(\n"
tokens = re.split(r"(\W+)", context.strip())[:-1]
for i in range(len(tokens)):
if i > 0:
grammar_str += "\n|\n"
grammar_str += "(" + grammar_continuous_with_arbitrary_end(tokens[i:]) + ")"
grammar_str += "\n)"
return LlamaGrammar.from_string(f"root ::= {grammar_str}", verbose=verbose)
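# Illustrative effect: the grammar is an alternation over every start token, each followed
# by a chain of optional continuations, so generation is constrained to a contiguous span
# of the context (e.g. for the context 'Paris is nice.' spans such as "Paris", "Paris is"
# or "is nice" are valid).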
class QuestionAnswering(Task):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.metric = load_metric("squad")
def predict(self, prompt: str, question: str, context: str):
# run model for inference
grammar = None
if self.use_grammar:
# context-sensitive grammar
try:
grammar = extractive_qa_grammar_fn(context)
except Exception as e:
logger.exception(
"Could not create grammar for context (potentially maximum recursion depth exceeded), therefore we do not use a grammar for this sample.\nContext (with length %d): %s",
len(context),
context,
exc_info=e,
)
response, usage = self.model(
system_message=SYSTEM_MESSAGE,
prompt=prompt,
prompt_appendix=self._get_text_for_datum(question, context),
grammar=grammar,
)
if not self.use_grammar:
# we postprocess the model output to return as answer
response = response.strip()
return response, usage
def _get_text_for_datum(self, question: str, context: str) -> str:
return (
"\nContext: " + '"' + context + '"' + "\nQuestion: " + '"' + question + '"'
)
def _evaluate_qa_sample(
self, prompt: str, id: str, question, context, gold_answers: list[str]
):
answer, usage = self.predict(prompt, question, context)
# input(f'*** PREDICTION ***\n\tContext: {datum["context"]}\n\tQuestion: {datum["question"]}\n\tAnswer: {answer}\n\tGold answer: {datum["answers"]["text"][0]}')
# TODO check if answer is lower-cased in metric computation
result = self.metric.compute(
predictions=[{"prediction_text": answer, "id": id}],
references=[{"answers": gold_answers, "id": id}],
)
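# the squad metric reports F1 on a 0-100 scale, so divide by 100 to normalise it to [0, 1]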
return result["f1"] / 100, usage
def _aggregate_result(self, results: list[float]) -> float:
return sum(results) / len(results)
def evaluate(
self,
prompt: str,
dataset: Dataset,
parent_histories: list[list[float]] | None = None,
):
def replace_symbol_for_grammar(sample: DatasetDatum):
symbol_replacement_mapping = {
"\u2013": "-",
"\u2014": "-",
}
symbol_replacement_mapping = dict(
(re.escape(k), v) for k, v in symbol_replacement_mapping.items()
)
symbol_replacement_pattern = re.compile(
"|".join(symbol_replacement_mapping.keys())
)
replace_fn = lambda text: symbol_replacement_pattern.sub(
lambda m: symbol_replacement_mapping[re.escape(m.group(0))], text
)
sample["context"] = replace_fn(sample["context"])
sample["answers"]["text"] = [
replace_fn(text) for text in sample["answers"]["text"]
]
return sample
if self.use_grammar:
# NOTE: the LlamaGrammar has issues with the symbols '–' (en dash) and '—' (em dash), therefore we replace all occurrences with '-' (hyphen)
dataset = dataset.map(replace_symbol_for_grammar, desc="Replacing symbols")
return super().evaluate(prompt, dataset, parent_histories=parent_histories)
@property
def metric_name(self):
return "f1"
@property
def base_prompt(self):
# TODO find good prompt
return """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
class SQuAD(QuestionAnswering):
shorthand = "squad"
def load_validation_set(
self, validation_dataset: str | None, validation_split: str | None
):
return super().load_validation_set("squad", "train[:200]")
def load_test_set(self, test_dataset: str, test_split: str | None):
return super().load_test_set("squad", "validation")
def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
return self._evaluate_qa_sample(
prompt, datum["id"], datum["question"], datum["context"], datum["answers"]
)
import logging
from argparse import Namespace
from functools import lru_cache
from llama_cpp import LlamaGrammar
from evoprompt.task.task import SYSTEM_MESSAGE, DatasetDatum, Task
logger = logging.getLogger(__name__)
# a simple grammar in GBNF notation which allows either positive or negative
# NOTE: for some reason the program crashes if the grammar is an instance attribute of the task class
SA_OUTPUTS = set(("positive", "negative"))
@lru_cache
def sa_grammar_fn(verbose: bool = False):
return LlamaGrammar.from_string(
"root ::= ({})".format("|".join(('"' + cons + '"') for cons in SA_OUTPUTS)),
verbose=verbose,
)
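# Illustrative output: the grammar string built here is
#   root ::= ("positive"|"negative")
# (label order may vary since SA_OUTPUTS is an unordered set), which forces the model
# to emit exactly one of the two labels.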
class SentimentAnalysis(Task):
def predict(self, prompt: str, text: str):
# run model for inference using grammar to constrain output
# TODO grammar also depends on prompt and vice-versa -> what are good labels?
response, usage = self.model(
system_message=SYSTEM_MESSAGE,
prompt=prompt,
prompt_appendix="\nInput: " + '"' + text + '"',
grammar=sa_grammar_fn() if self.use_grammar else None,
)
if not self.use_grammar:
# we postprocess the model output to return as answer
response = response.strip()
# input(f"*** PREDICTION***\n\tText: {text}\n\tSentiment: {response}")
return response, usage
def _get_text_for_datum(self, datum: DatasetDatum) -> str:
return datum["text"]
def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
sst2_labels = {"negative": 0, "positive": 1}
response, usage = self.predict(
prompt=prompt, text=self._get_text_for_datum(datum)
)
response = response.lower()
if self.use_grammar:
# model output is from label space
answer_label = sst2_labels[response]
else:
answer_label = None
for label in sst2_labels.keys():
if label in response:
answer_label = sst2_labels[label]
break
else:
logger.warning(f"Invalid answer: {response}")
return "failed", usage
classification_result = (
"incorrect" if answer_label != datum["label"] else "correct"
)
return classification_result, usage
def _aggregate_result(self, results: list[str]) -> float:
num_correct_results = sum(1 for result in results if result == "correct")
accuracy = num_correct_results / len(results)
return accuracy
@property
def metric_name(self):
return "accuracy"
@property
def base_prompt(self):
# from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning
return """In this task, you are given sentences from movie reviews. The task is to classify a sentence as 'positive' if the sentiment of the sentence is positive or as 'negative' if the sentiment of the sentence is negative. Return label only without any other text."""
class SST2(SentimentAnalysis):
shorthand = "sa"
def load_validation_set(
self, validation_dataset: str | None, validation_split: str | None
):
return super().load_validation_set("SetFit/sst2", "validation[:200]")
def load_test_set(self, test_dataset: str, test_split: str | None):
return super().load_test_set("SetFit/sst2", "test")
import logging
import re
from abc import abstractmethod
from abc import ABCMeta, abstractmethod
from argparse import Namespace
from functools import lru_cache
from statistics import mean
from typing import Union
from datasets import Dataset, load_dataset
from evaluate import load as load_metric
from collections import deque
from llama_cpp import LlamaGrammar
from tqdm import tqdm
from evoprompt.cli import argument_parser
from evoprompt.models import LLMModel
from evoprompt.opt_types import ModelUsage
from evoprompt.utils import log_calls
@@ -114,7 +112,7 @@ class ParentBaselineBasedStopping(EarlyStoppingMonitor):
return False
class Task:
class Task(metaclass=ABCMeta):
shorthand: str
validation_dataset: Dataset
test_dataset: Dataset
@@ -122,17 +120,19 @@ class Task:
def __init__(
self,
model: Union[LLMModel],
validation_dataset: str,
test_dataset: str,
validation_dataset: str | None = None,
test_dataset: str | None = None,
*,
use_grammar: bool,
no_early_stopping: bool = False,
evaluate_shortest_first: bool = False,
validation_split: str | None = None,
test_split: str | None = None,
debug: bool = False,
**kwargs,
) -> None:
self.model = model
self.debug = debug
# whether we use the grammar to constrain the model output or not
self.use_grammar = use_grammar
self.no_early_stopping = no_early_stopping
@@ -141,10 +141,24 @@ class Task:
logger.warning(
"Both 'evaluate_shortest_first' and 'no_early_stopping' are set. This makes no sense"
)
self.validation_dataset = load_dataset(
validation_dataset, split=validation_split
self.validation_dataset = self.load_validation_set(
validation_dataset, validation_split
)
self.test_dataset = load_dataset(test_dataset, split=test_split)
self.test_dataset = self.load_test_set(test_dataset, test_split)
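# NOTE: in debug mode the loaders below keep only small subsets (5 validation samples, 20 test samples) so runs stay fast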
def load_validation_set(
self, validation_dataset: str, validation_split: str | None
):
dataset = load_dataset(validation_dataset, split=validation_split)
if self.debug and len(dataset) > 5:
dataset = dataset.select(range(5))
return dataset
def load_test_set(self, test_dataset: str, test_split: str | None):
dataset = load_dataset(test_dataset, split=test_split)
if self.debug and len(dataset) > 20:
dataset = dataset.select(range(20))
return dataset
@abstractmethod
def predict(self, prompt: str, *args, **kwargs) -> tuple[str, ModelUsage]:
@@ -229,259 +243,3 @@ class Task:
@abstractmethod
def base_prompt(self) -> str:
pass
# a simple grammar in GBNF notation which allows either positive or negative
# NOTE: for some reason the program crashes if the grammar is an instance attribute of the task class
SA_OUTPUTS = set(("positive", "negative"))
@lru_cache
def sa_grammar_fn(verbose: bool = False):
return LlamaGrammar.from_string(
"root ::= ({})".format("|".join(('"' + cons + '"') for cons in SA_OUTPUTS)),
verbose=verbose,
)
class SentimentAnalysis(Task):
shorthand = "sa"
def __init__(self, model, options: Namespace):
super().__init__(
model,
validation_dataset="SetFit/sst2",
test_dataset="SetFit/sst2",
use_grammar=options.use_grammar,
no_early_stopping=options.no_early_stopping,
evaluate_shortest_first=options.evaluate_shortest_first,
validation_split=f"validation[:{5 if options.debug else 200}]",
test_split="test[:20]" if options.debug else "test",
)
def predict(self, prompt: str, text: str):
# run model for inference using grammar to constrain output
# TODO grammar also depends on prompt and vice-versa -> what are good labels?
response, usage = self.model(
system_message=SYSTEM_MESSAGE,
prompt=prompt,
prompt_appendix="\nInput: " + '"' + text + '"',
grammar=sa_grammar_fn() if self.use_grammar else None,
)
if not self.use_grammar:
# we postprocess the model output to return as answer
response = response.strip()
# input(f"*** PREDICTION***\n\tText: {text}\n\tSentiment: {response}")
return response, usage
def _get_text_for_datum(self, datum: DatasetDatum) -> str:
return datum["text"]
def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
sst2_labels = {"negative": 0, "positive": 1}
response, usage = self.predict(
prompt=prompt, text=self._get_text_for_datum(datum)
)
response = response.lower()
if self.use_grammar:
# model output is from label space
answer_label = sst2_labels[response]
else:
answer_label = None
for label in sst2_labels.keys():
if label in response:
answer_label = sst2_labels[label]
break
else:
logger.warning(f"Invalid answer: {response}")
return "failed", usage
classification_result = (
"incorrect" if answer_label != datum["label"] else "correct"
)
return classification_result, usage
def _aggregate_result(self, results: list[str]) -> float:
num_correct_results = sum(1 for result in results if result == "correct")
accuracy = num_correct_results / len(results)
return accuracy
@property
def metric_name(self):
return "accuracy"
@property
def base_prompt(self):
# from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning
return """In this task, you are given sentences from movie reviews. The task is to classify a sentence as 'positive' if the sentiment of the sentence is positive or as 'negative' if the sentiment of the sentence is negative. Return label only without any other text."""
def grammar_continuous_with_arbitrary_end(sequence: list[str], quote: str = '"'):
if len(sequence) == 0:
return ""
if quote is None:
quote = ""
escape_item = lambda text: text.replace('"', '\\"')
grammar = quote + escape_item(sequence[0]) + quote
for idx, item in enumerate(sequence[1:]):
new_item = "(" + quote + escape_item(item) + quote + ")?"
if idx == 0:
grammar += new_item
else:
grammar = grammar[: -2 * idx] + new_item + grammar[-2 * idx :]
return grammar
# a grammar in GBNF notation which takes the context into account
# NOTE: for some reason the program crashes if the grammar is an instance attribute of the task class
@lru_cache
def extractive_qa_grammar_fn(context: str, verbose: bool = False):
grammar_str = "(\n"
tokens = re.split(r"(\W+)", context.strip())[:-1]
for i in range(len(tokens)):
if i > 0:
grammar_str += "\n|\n"
grammar_str += "(" + grammar_continuous_with_arbitrary_end(tokens[i:]) + ")"
grammar_str += "\n)"
return LlamaGrammar.from_string(f"root ::= {grammar_str}", verbose=verbose)
class QuestionAnswering(Task):
shorthand = "qa"
def __init__(self, model, options: Namespace):
self.metric = load_metric("squad")
super().__init__(
model,
"squad",
"squad",
use_grammar=options.use_grammar,
no_early_stopping=options.no_early_stopping,
evaluate_shortest_first=options.evaluate_shortest_first,
validation_split=f"train[:{5 if options.debug else 200}]",
test_split="validation[:20]" if options.debug else "validation",
)
def predict(self, prompt: str, datum: DatasetDatum):
# run model for inference
grammar = None
if self.use_grammar:
# context-sensitive grammar
context = datum["context"]
try:
grammar = extractive_qa_grammar_fn(context)
except Exception as e:
logger.exception(
"Could not create grammar for context (potentially maximum recursion depth exceeded), therefore we do not use a grammar for this sample.\nContext (with length %d): %s",
len(context),
context,
exc_info=e,
)
response, usage = self.model(
system_message=SYSTEM_MESSAGE,
prompt=prompt,
prompt_appendix=self._get_text_for_datum(datum),
grammar=grammar,
)
if not self.use_grammar:
# we postprocess the model output to return as answer
response = response.strip()
return response, usage
def _get_text_for_datum(self, datum: DatasetDatum) -> str:
return (
"\nContext: "
+ '"'
+ datum["context"]
+ '"'
+ "\nQuestion: "
+ '"'
+ datum["question"]
+ '"'
)
def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
answer, usage = self.predict(prompt, datum)
# input(f'*** PREDICTION ***\n\tContext: {datum["context"]}\n\tQuestion: {datum["question"]}\n\tAnswer: {answer}\n\tGold answer: {datum["answers"]["text"][0]}')
# TODO check if answer is lower-cased in metric computation
result = self.metric.compute(
predictions=[{"prediction_text": answer, "id": datum["id"]}],
references=[{"answers": datum["answers"], "id": datum["id"]}],
)
return result["f1"] / 100, usage
def _aggregate_result(self, results: list[float]) -> float:
return sum(results) / len(results)
def evaluate(
self,
prompt: str,
dataset: Dataset,
parent_histories: list[list[float]] | None = None,
):
def replace_symbol_for_grammar(sample: DatasetDatum):
symbol_replacement_mapping = {
"\u2013": "-",
"\u2014": "-",
}
symbol_replacement_mapping = dict(
(re.escape(k), v) for k, v in symbol_replacement_mapping.items()
)
symbol_replacement_pattern = re.compile(
"|".join(symbol_replacement_mapping.keys())
)
replace_fn = lambda text: symbol_replacement_pattern.sub(
lambda m: symbol_replacement_mapping[re.escape(m.group(0))], text
)
sample["context"] = replace_fn(sample["context"])
sample["answers"]["text"] = [
replace_fn(text) for text in sample["answers"]["text"]
]
return sample
if self.use_grammar:
# NOTE: the LlamaGrammar has issues with the symbols '–' (en dash) and '—' (em dash), therefore we replace all occurrences with '-' (hyphen)
dataset = dataset.map(replace_symbol_for_grammar, desc="Replacing symbols")
return super().evaluate(prompt, dataset, parent_histories=parent_histories)
@property
def metric_name(self):
return "f1"
@property
def base_prompt(self):
# TODO find good prompt
return """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
tasks = {task.shorthand: task for task in Task.__subclasses__()}
def get_task(name: str, evaluation_model: LLMModel, options: Namespace):
if name not in tasks:
raise ValueError("Model %s does not exist", name)
return tasks[name](evaluation_model, options)
argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
argument_group = argument_parser.add_argument_group("Task arguments")
argument_group.add_argument("--use-grammar", "-g", action="store_true")
argument_group.add_argument("--no-early-stopping", action="store_true")
argument_group.add_argument("--evaluate-shortest-first", action="store_true")
argument_group.add_argument(
"--task", "-t", type=str, required=True, choices=tasks.keys()
)
@@ -140,3 +140,8 @@ class log_calls:
arguments[argument_name] = value
return arguments
def get_all_subclasses(cls):
return set(cls.__subclasses__()).union(
[s for c in cls.__subclasses__() for s in get_all_subclasses(c)]
)
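# Illustrative behaviour: get_all_subclasses(Task) walks the inheritance tree recursively,
# returning direct subclasses (e.g. QuestionAnswering, SentimentAnalysis) as well as
# nested ones (e.g. SQuAD, SST2).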
\ No newline at end of file
@@ -72,7 +72,7 @@ if __name__ == "__main__":
case "openai":
evaluation_model = Llama(options)
task = get_task(options.task, evaluation_model, options)
task = get_task(options.task, evaluation_model, **options.__dict__)
logger.info(
f"Running with task {task.__class__.__name__} on dataset {task.validation_dataset.info.dataset_name}"
)