From c27d342bcd0f919e50f51870bfc8c0638a520671 Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Mon, 29 Jul 2024 17:53:30 +0200
Subject: [PATCH] Refactor tasks

---
 evoprompt/evolution.py               |   2 +-
 evoprompt/task.py                    | 487 ----------------------------
 evoprompt/task/__init__.py           |  33 ++
 evoprompt/task/question_answering.py | 160 +++++++++
 evoprompt/task/sentiment_analysis.py |  97 ++++++
 evoprompt/task/task.py               | 245 ++++++++++++++
 evoprompt/utils.py                   |   5 +
 main.py                              |   2 +-
 8 files changed, 542 insertions(+), 489 deletions(-)
 delete mode 100644 evoprompt/task.py
 create mode 100644 evoprompt/task/__init__.py
 create mode 100644 evoprompt/task/question_answering.py
 create mode 100644 evoprompt/task/sentiment_analysis.py
 create mode 100644 evoprompt/task/task.py

diff --git a/evoprompt/evolution.py b/evoprompt/evolution.py
index 9026d4d..e96e0ee 100644
--- a/evoprompt/evolution.py
+++ b/evoprompt/evolution.py
@@ -10,7 +10,7 @@ from evoprompt.models import LLMModel
 from evoprompt.opt_types import ModelUsage, Prompt
 from evoprompt.optimization import PromptOptimization
 from evoprompt.task import Task
-from evoprompt.utils import log_calls
+from evoprompt.utils import log_calls, get_all_subclasses
 
 logger = logging.getLogger(__name__)
 
diff --git a/evoprompt/task.py b/evoprompt/task.py
deleted file mode 100644
index d7367b5..0000000
--- a/evoprompt/task.py
+++ /dev/null
@@ -1,487 +0,0 @@
-import logging
-import re
-from abc import abstractmethod
-from argparse import Namespace
-from functools import lru_cache
-from statistics import mean
-from typing import Union
-
-from datasets import Dataset, load_dataset
-from evaluate import load as load_metric
-from llama_cpp import LlamaGrammar, deque
-from tqdm import tqdm
-
-from evoprompt.cli import argument_parser
-from evoprompt.models import LLMModel
-from evoprompt.opt_types import ModelUsage
-from evoprompt.utils import log_calls
-
-logger = logging.getLogger(__name__)
-
-
-SYSTEM_MESSAGE = """
-You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
-"""
-
-DatasetDatum = dict
-
-
-class EarlyStoppingMonitor:
-
-    @abstractmethod
-    def update(self, score: float) -> bool:
-        raise NotImplementedError
-
-
-class MomentBasedStopping(EarlyStoppingMonitor):
-    """
-    Watch the first derivative (moment) of the metric to determine when to stop.
- """ - - def __init__( - self, - *, - patience: int = 10, - start_after: int = 20, - min_moment_magnitude: float = 0.001, - ): - self.patience = patience - self.start_after = start_after - self.min_moment_magnitude = min_moment_magnitude - - self.moment_magnitudes = deque(maxlen=patience) - self.last_score = 0.0 - self.num_calls = 0 - - def update(self, score: float) -> bool: - # caclulate the current moment (dx/dt) - self.num_calls += 1 - if self.num_calls < self.start_after: - return False - - self.moment_magnitudes.append(abs(score - self.last_score)) - self.last_score = score - if len(self.moment_magnitudes) < self.patience: - return False - - if mean(self.moment_magnitudes) < self.min_moment_magnitude: - return True - - return False - - -class ParentBaselineBasedStopping(EarlyStoppingMonitor): - - def __init__( - self, - parent_histories: list[list[float]], - *, - patience: int = 10, - start_after: int = 20, - min_improvement: float = 0.001, - ): - self.parent_histories = parent_histories - self.patience = patience - self.start_after = start_after - self.min_improvement = min_improvement - self.num_calls = 0 - self.improvement_memory = deque(maxlen=patience) - - def update(self, score: float) -> bool: - self.num_calls += 1 - if self.num_calls < self.start_after: - return False - - parent_values = [ # get the metric value of the parents at the current step - ( - parent_history[self.num_calls - 1] - if len(parent_history) >= self.num_calls - else parent_history[-1] # extend with last value - ) - for parent_history in self.parent_histories - ] - self.improvement_memory.append( - score - max(parent_values) # compare with the best parent - ) - - if len(self.improvement_memory) < self.patience: - return False - - if max(self.improvement_memory) < self.min_improvement: - # if the highest improvement is less than the minimum improvement, we stop - return True - - return False - - -class Task: - shorthand: str - validation_dataset: Dataset - test_dataset: Dataset - - def __init__( - self, - model: Union[LLMModel], - validation_dataset: str, - test_dataset: str, - *, - use_grammar: bool, - no_early_stopping: bool = False, - evaluate_shortest_first: bool = False, - validation_split: str | None = None, - test_split: str | None = None, - ) -> None: - self.model = model - - # whether we use the grammar to constrain the model output or not - self.use_grammar = use_grammar - self.no_early_stopping = no_early_stopping - self.evaluate_shortest_first = evaluate_shortest_first - if evaluate_shortest_first and no_early_stopping: - logger.warning( - "Both 'evaluate_shortest_first' and 'no_early_stopping' are set. 
-            )
-        self.validation_dataset = load_dataset(
-            validation_dataset, split=validation_split
-        )
-        self.test_dataset = load_dataset(test_dataset, split=test_split)
-
-    @abstractmethod
-    def predict(self, prompt: str, *args, **kwargs) -> tuple[str, ModelUsage]:
-        pass
-
-    @abstractmethod
-    def _evaluate_sample(
-        self, prompt: str, datum: DatasetDatum
-    ) -> tuple[str, ModelUsage]:
-        pass
-
-    @abstractmethod
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
-        pass
-
-    @abstractmethod
-    def _aggregate_result(self, results: list) -> float:
-        pass
-
-    def evaluate(
-        self,
-        prompt: str,
-        dataset: Dataset,
-        parent_histories: list[list[float]] | None = None,
-    ) -> tuple[float, ModelUsage, list[float]]:
-
-        early_stopping: EarlyStoppingMonitor
-        early_stopping_params = {
-            "patience": max(len(dataset) // 20, 5),
-            "start_after": max(len(dataset) // 5, 5),
-        }
-        if parent_histories is not None:
-            early_stopping = ParentBaselineBasedStopping(
-                parent_histories, **early_stopping_params
-            )
-        else:
-            early_stopping = MomentBasedStopping(**early_stopping_params)
-
-        results: list = []
-        if self.evaluate_shortest_first:
-            dataset = sorted(dataset, key=lambda x: len(self._get_text_for_datum(x)))
-        dataset_iterator: tqdm[DatasetDatum] = tqdm(
-            dataset, desc="evaluating prompt", leave=False
-        )
-        evaluation_usage = ModelUsage()
-        evaluation_history = []
-
-        for datum in dataset_iterator:
-            result, usage = self._evaluate_sample(prompt, datum)
-            results.append(result)
-            current_metric = self._aggregate_result(results)
-            dataset_iterator.set_postfix(
-                {self.metric_name: f"{current_metric*100:.1f}%"}
-            )
-            evaluation_usage += usage
-            evaluation_history.append(current_metric)
-            if not self.no_early_stopping and early_stopping.update(current_metric):
-                logger.info(
-                    f"Early stopping after {len(results)} samples with {self.metric_name} of {current_metric*100:.1f}%"
-                )
-                break
-            # input(f'F1 score: {current_metric:.2f}')
-
-        return self._aggregate_result(results), evaluation_usage, evaluation_history
-
-    @log_calls("Evaluating validation dataset")
-    def evaluate_validation(
-        self, prompt: str, parent_histories: list[list[float]] | None = None
-    ):
-        return self.evaluate(prompt, self.validation_dataset, parent_histories)
-
-    @log_calls("Evaluating test dataset")
-    def evaluate_test(self, prompt: str):
-        return self.evaluate(prompt, self.test_dataset)
-
-    @property
-    @abstractmethod
-    def metric_name(self) -> str:
-        pass
-
-    @property
-    @abstractmethod
-    def base_prompt(self) -> str:
-        pass
-
-
-# a simple grammar in GBNF notation which allows either positive or negative
-# NOTE: for some reason the program crashes if the grammar is an instance attribute of the task class
-SA_OUTPUTS = set(("positive", "negative"))
-
-
-@lru_cache
-def sa_grammar_fn(verbose: bool = False):
-    return LlamaGrammar.from_string(
-        "root ::= ({})".format("|".join(('"' + cons + '"') for cons in SA_OUTPUTS)),
-        verbose=verbose,
-    )
-
-
-class SentimentAnalysis(Task):
-    shorthand = "sa"
-
-    def __init__(self, model, options: Namespace):
-        super().__init__(
-            model,
-            validation_dataset="SetFit/sst2",
-            test_dataset="SetFit/sst2",
-            use_grammar=options.use_grammar,
-            no_early_stopping=options.no_early_stopping,
-            evaluate_shortest_first=options.evaluate_shortest_first,
-            validation_split=f"validation[:{5 if options.debug else 200}]",
-            test_split="test[:20]" if options.debug else "test",
-        )
-
-    def predict(self, prompt: str, text: str):
-        # run model for inference using grammar to constrain output
-        # TODO grammar also depends on prompt and vice-versa -> what are good labels?
-        response, usage = self.model(
-            system_message=SYSTEM_MESSAGE,
-            prompt=prompt,
-            prompt_appendix="\nInput: " + '"' + text + '"',
-            grammar=sa_grammar_fn() if self.use_grammar else None,
-        )
-
-        if not self.use_grammar:
-            # we postprocess the model output to return as answer
-            response = response.strip()
-        # input(f"*** PREDICTION***\n\tText: {text}\n\tSentiment: {response}")
-
-        return response, usage
-
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
-        return datum["text"]
-
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
-        sst2_labels = {"negative": 0, "positive": 1}
-
-        response, usage = self.predict(
-            prompt=prompt, text=self._get_text_for_datum(datum)
-        )
-        response = response.lower()
-        if self.use_grammar:
-            # model output is from label space
-            answer_label = sst2_labels[response]
-        else:
-            answer_label = None
-            for label in sst2_labels.keys():
-                if label in response:
-                    answer_label = sst2_labels[label]
-                    break
-            else:
-                logger.warning(f"Invalid answer: {response}")
-                return "failed", usage
-
-        classification_result = (
-            "incorrect" if answer_label != datum["label"] else "correct"
-        )
-        return classification_result, usage
-
-    def _aggregate_result(self, results: list[str]) -> float:
-        num_correct_results = sum(1 for result in results if result == "correct")
-        accuracy = num_correct_results / len(results)
-        return accuracy
-
-    @property
-    def metric_name(self):
-        return "accuracy"
-
-    @property
-    def base_prompt(self):
-        # from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning
-        return """In this task, you are given sentences from movie reviews. The task is to classify a sentence as 'positive' if the sentiment of the sentence is positive or as 'negative' if the sentiment of the sentence is negative. Return label only without any other text."""
-
-
-def grammar_continuous_with_arbitrary_end(sequence: list[str], quote: str = '"'):
-    if len(sequence) == 0:
-        return ""
-
-    if quote is None:
-        quote = ""
-
-    escape_item = lambda text: text.replace('"', '\\"')
-
-    grammar = quote + escape_item(sequence[0]) + quote
-    for idx, item in enumerate(sequence[1:]):
-        new_item = "(" + quote + escape_item(item) + quote + ")?"
-        if idx == 0:
-            grammar += new_item
-        else:
-            grammar = grammar[: -2 * idx] + new_item + grammar[-2 * idx :]
-    return grammar
-
-
-# a grammar in GBNF notation which takes the context into account
-# NOTE: for some reason the program crashes if the grammar is an instance attribute of the task class
-@lru_cache
-def extractive_qa_grammar_fn(context: str, verbose: bool = False):
-    grammar_str = "(\n"
-    tokens = re.split(r"(\W+)", context.strip())[:-1]
-
-    for i in range(len(tokens)):
-        if i > 0:
-            grammar_str += "\n|\n"
-        grammar_str += "(" + grammar_continuous_with_arbitrary_end(tokens[i:]) + ")"
-    grammar_str += "\n)"
-    return LlamaGrammar.from_string(f"root ::= {grammar_str}", verbose=verbose)
-
-
-class QuestionAnswering(Task):
-    shorthand = "qa"
-
-    def __init__(self, model, options: Namespace):
-
-        self.metric = load_metric("squad")
-
-        super().__init__(
-            model,
-            "squad",
-            "squad",
-            use_grammar=options.use_grammar,
-            no_early_stopping=options.no_early_stopping,
-            evaluate_shortest_first=options.evaluate_shortest_first,
-            validation_split=f"train[:{5 if options.debug else 200}]",
-            test_split="validation[:20]" if options.debug else "validation",
-        )
-
-    def predict(self, prompt: str, datum: DatasetDatum):
-        # run model for inference
-        grammar = None
-        if self.use_grammar:
-            # context-sensitive grammar
-            context = datum["context"]
-            try:
-                grammar = extractive_qa_grammar_fn(context)
-            except Exception as e:
-                logger.exception(
-                    "Could not create grammar for context (potentially maximum recursion depth exceeded), therefore we do not use a grammar for this sample.\nContext (with length %d): %s",
-                    len(context),
-                    context,
-                    exc_info=e,
-                )
-
-        response, usage = self.model(
-            system_message=SYSTEM_MESSAGE,
-            prompt=prompt,
-            prompt_appendix=self._get_text_for_datum(datum),
-            grammar=grammar,
-        )
-
-        if not self.use_grammar:
-            # we postprocess the model output to return as answer
-            response = response.strip()
-
-        return response, usage
-
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
-        return (
-            "\nContext: "
-            + '"'
-            + datum["context"]
-            + '"'
-            + "\nQuestion: "
-            + '"'
-            + datum["question"]
-            + '"'
-        )
-
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
-        answer, usage = self.predict(prompt, datum)
-        # input(f'*** PREDICTION ***\n\tContext: {datum["context"]}\n\tQuestion: {datum["question"]}\n\tAnswer: {answer}\n\tGold answer: {datum["answers"]["text"][0]}')
-        # TODO check if answer is lower-cased in metric computation
-
-        result = self.metric.compute(
-            predictions=[{"prediction_text": answer, "id": datum["id"]}],
-            references=[{"answers": datum["answers"], "id": datum["id"]}],
-        )
-
-        return result["f1"] / 100, usage
-
-    def _aggregate_result(self, results: list[float]) -> float:
-        return sum(results) / len(results)
-
-    def evaluate(
-        self,
-        prompt: str,
-        dataset: Dataset,
-        parent_histories: list[list[float]] | None = None,
-    ):
-        def replace_symbol_for_grammar(sample: DatasetDatum):
-            symbol_replacement_mapping = {
-                "\u2013": "-",
-                "\u2014": "-",
-            }
-            symbol_replacement_mapping = dict(
-                (re.escape(k), v) for k, v in symbol_replacement_mapping.items()
-            )
-            symbol_replacement_pattern = re.compile(
-                "|".join(symbol_replacement_mapping.keys())
-            )
-            replace_fn = lambda text: symbol_replacement_pattern.sub(
-                lambda m: symbol_replacement_mapping[re.escape(m.group(0))], text
-            )
-            sample["context"] = replace_fn(sample["context"])
-            sample["answers"]["text"] = [
-                replace_fn(text) for text in sample["answers"]["text"]
-            ]
-            return sample
-
-        if self.use_grammar:
-            # NOTE: the LlamaGrammar has issues with symbol '–' therefore we replace all occurences with '-' (hyphen)
-            dataset = dataset.map(replace_symbol_for_grammar, desc="Replacing symbols")
-        return super().evaluate(prompt, dataset, parent_histories=parent_histories)
-
-    @property
-    def metric_name(self):
-        return "f1"
-
-    @property
-    def base_prompt(self):
-        # TODO find good prompt
-        return """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
-
-
-tasks = {task.shorthand: task for task in Task.__subclasses__()}
-
-
-def get_task(name: str, evaluation_model: LLMModel, options: Namespace):
-    if name not in tasks:
-        raise ValueError("Model %s does not exist", name)
-    return tasks[name](evaluation_model, options)
-
-
-argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
-argument_group = argument_parser.add_argument_group("Task arguments")
-argument_group.add_argument("--use-grammar", "-g", action="store_true")
-argument_group.add_argument("--no-early-stopping", action="store_true")
-argument_group.add_argument("--evaluate-shortest-first", action="store_true")
-argument_group.add_argument(
-    "--task", "-t", type=str, required=True, choices=tasks.keys()
-)
diff --git a/evoprompt/task/__init__.py b/evoprompt/task/__init__.py
new file mode 100644
index 0000000..f810994
--- /dev/null
+++ b/evoprompt/task/__init__.py
@@ -0,0 +1,33 @@
+from argparse import Namespace
+
+from evoprompt.cli import argument_parser
+from evoprompt.models import LLMModel
+from evoprompt.task.question_answering import QuestionAnswering
+from evoprompt.task.sentiment_analysis import SentimentAnalysis
+
+# make sure to run definitions of subclasses of Task first
+from evoprompt.task.task import Task
+from evoprompt.utils import get_all_subclasses
+
+# at this point, we assume that all subclasses of Task have been defined
+tasks = {
+    task.shorthand: task
+    for task in get_all_subclasses(Task)
+    if hasattr(task, "shorthand")
+}
+
+
+def get_task(name: str, evaluation_model: LLMModel, **options):
+    if name not in tasks:
+        raise ValueError(f"Task {name} does not exist")
+    return tasks[name](evaluation_model, **options)
+
+
+argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
+argument_group = argument_parser.add_argument_group("Task arguments")
+argument_group.add_argument("--use-grammar", "-g", action="store_true")
+argument_group.add_argument("--no-early-stopping", action="store_true")
+argument_group.add_argument("--evaluate-shortest-first", action="store_true")
+argument_group.add_argument(
+    "--task", "-t", type=str, required=True, choices=tasks.keys()
+)
diff --git a/evoprompt/task/question_answering.py b/evoprompt/task/question_answering.py
new file mode 100644
index 0000000..25c9aac
--- /dev/null
+++ b/evoprompt/task/question_answering.py
@@ -0,0 +1,160 @@
+import logging
+import re
+from argparse import Namespace
+from functools import lru_cache
+
+from datasets import Dataset
+from evaluate import load as load_metric
+from llama_cpp import LlamaGrammar
+
+from evoprompt.task.task import SYSTEM_MESSAGE, DatasetDatum, Task
+
+logger = logging.getLogger(__name__)
+
+
+def grammar_continuous_with_arbitrary_end(sequence: list[str], quote: str = '"'):
+    if len(sequence) == 0:
+        return ""
+
+    if quote is None:
+        quote = ""
+
+    escape_item = lambda text: text.replace('"', '\\"')
+
+    grammar = quote + escape_item(sequence[0]) + quote
+    for idx, item in enumerate(sequence[1:]):
+        new_item = "(" + quote + escape_item(item) + quote + ")?"
+        if idx == 0:
+            grammar += new_item
+        else:
+            grammar = grammar[: -2 * idx] + new_item + grammar[-2 * idx :]
+    return grammar
+
+
+# a grammar in GBNF notation which takes the context into account
+# NOTE: for some reason the program crashes if the grammar is an instance attribute of the task class
+@lru_cache
+def extractive_qa_grammar_fn(context: str, verbose: bool = False):
+    grammar_str = "(\n"
+    tokens = re.split(r"(\W+)", context.strip())[:-1]
+
+    for i in range(len(tokens)):
+        if i > 0:
+            grammar_str += "\n|\n"
+        grammar_str += "(" + grammar_continuous_with_arbitrary_end(tokens[i:]) + ")"
+    grammar_str += "\n)"
+    return LlamaGrammar.from_string(f"root ::= {grammar_str}", verbose=verbose)
+
+
+class QuestionAnswering(Task):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.metric = load_metric("squad")
+
+    def predict(self, prompt: str, question: str, context: str):
+        # run model for inference
+        grammar = None
+        if self.use_grammar:
+            # context-sensitive grammar
+            context = context
+            try:
+                grammar = extractive_qa_grammar_fn(context)
+            except Exception as e:
+                logger.exception(
+                    "Could not create grammar for context (potentially maximum recursion depth exceeded), therefore we do not use a grammar for this sample.\nContext (with length %d): %s",
+                    len(context),
+                    context,
+                    exc_info=e,
+                )
+
+        response, usage = self.model(
+            system_message=SYSTEM_MESSAGE,
+            prompt=prompt,
+            prompt_appendix=self._get_text_for_datum(question, context),
+            grammar=grammar,
+        )
+
+        if not self.use_grammar:
+            # we postprocess the model output to return as answer
+            response = response.strip()
+
+        return response, usage
+
+    def _get_text_for_datum(self, question: str, context: str) -> str:
+        return (
+            "\nContext: " + '"' + context + '"' + "\nQuestion: " + '"' + question + '"'
+        )
+
+    def _evaluate_qa_sample(
+        self, prompt: str, id: str, question, context, gold_answers: list[str]
+    ):
+        answer, usage = self.predict(prompt, question, context)
+        # input(f'*** PREDICTION ***\n\tContext: {datum["context"]}\n\tQuestion: {datum["question"]}\n\tAnswer: {answer}\n\tGold answer: {datum["answers"]["text"][0]}')
+        # TODO check if answer is lower-cased in metric computation
+
+        result = self.metric.compute(
+            predictions=[{"prediction_text": answer, "id": id}],
+            references=[{"answers": gold_answers, "id": id}],
+        )
+        return result["f1"] / 100, usage
+
+    def _aggregate_result(self, results: list[float]) -> float:
+        return sum(results) / len(results)
+
+    def evaluate(
+        self,
+        prompt: str,
+        dataset: Dataset,
+        parent_histories: list[list[float]] | None = None,
+    ):
+        def replace_symbol_for_grammar(sample: DatasetDatum):
+            symbol_replacement_mapping = {
+                "\u2013": "-",
+                "\u2014": "-",
+            }
+            symbol_replacement_mapping = dict(
+                (re.escape(k), v) for k, v in symbol_replacement_mapping.items()
+            )
+            symbol_replacement_pattern = re.compile(
+                "|".join(symbol_replacement_mapping.keys())
+            )
+            replace_fn = lambda text: symbol_replacement_pattern.sub(
+                lambda m: symbol_replacement_mapping[re.escape(m.group(0))], text
+            )
+            sample["context"] = replace_fn(sample["context"])
+            sample["answers"]["text"] = [
+                replace_fn(text) for text in sample["answers"]["text"]
+            ]
+            return sample
+
+        if self.use_grammar:
+            # NOTE: the LlamaGrammar has issues with symbol '–' therefore we replace all occurrences with '-' (hyphen)
+            dataset = dataset.map(replace_symbol_for_grammar, desc="Replacing symbols")
+        return super().evaluate(prompt, dataset, parent_histories=parent_histories)
+
+    @property
+    def metric_name(self):
+        return "f1"
+
+    @property
+    def base_prompt(self):
+        # TODO find good prompt
+        return """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
+
+
+class SQuAD(QuestionAnswering):
+    shorthand = "squad"
+
+    def load_validation_set(
+        self, validation_dataset: str | None, validation_split: str | None
+    ):
+        return super().load_validation_set("squad", "train[:200]")
+
+    def load_test_set(self, test_dataset: str, test_split: str | None):
+        return super().load_test_set("squad", "validation")
+
+    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        return self._evaluate_qa_sample(
+            prompt, datum["id"], datum["question"], datum["context"], datum["answers"]
+        )
diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py
new file mode 100644
index 0000000..a31a4fc
--- /dev/null
+++ b/evoprompt/task/sentiment_analysis.py
@@ -0,0 +1,97 @@
+import logging
+from argparse import Namespace
+from functools import lru_cache
+
+from llama_cpp import LlamaGrammar
+
+from evoprompt.task.task import SYSTEM_MESSAGE, DatasetDatum, Task
+
+logger = logging.getLogger(__name__)
+
+
+# a simple grammar in GBNF notation which allows either positive or negative
+# NOTE: for some reason the program crashes if the grammar is an instance attribute of the task class
+SA_OUTPUTS = set(("positive", "negative"))
+
+
+@lru_cache
+def sa_grammar_fn(verbose: bool = False):
+    return LlamaGrammar.from_string(
+        "root ::= ({})".format("|".join(('"' + cons + '"') for cons in SA_OUTPUTS)),
+        verbose=verbose,
+    )
+
+
+class SentimentAnalysis(Task):
+    shorthand = "sa"
+
+    def predict(self, prompt: str, text: str):
+        # run model for inference using grammar to constrain output
+        # TODO grammar also depends on prompt and vice-versa -> what are good labels?
+        response, usage = self.model(
+            system_message=SYSTEM_MESSAGE,
+            prompt=prompt,
+            prompt_appendix="\nInput: " + '"' + text + '"',
+            grammar=sa_grammar_fn() if self.use_grammar else None,
+        )
+
+        if not self.use_grammar:
+            # we postprocess the model output to return as answer
+            response = response.strip()
+        # input(f"*** PREDICTION***\n\tText: {text}\n\tSentiment: {response}")
+
+        return response, usage
+
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        return datum["text"]
+
+    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        sst2_labels = {"negative": 0, "positive": 1}
+
+        response, usage = self.predict(
+            prompt=prompt, text=self._get_text_for_datum(datum)
+        )
+        response = response.lower()
+        if self.use_grammar:
+            # model output is from label space
+            answer_label = sst2_labels[response]
+        else:
+            answer_label = None
+            for label in sst2_labels.keys():
+                if label in response:
+                    answer_label = sst2_labels[label]
+                    break
+            else:
+                logger.warning(f"Invalid answer: {response}")
+                return "failed", usage
+
+        classification_result = (
+            "incorrect" if answer_label != datum["label"] else "correct"
+        )
+        return classification_result, usage
+
+    def _aggregate_result(self, results: list[str]) -> float:
+        num_correct_results = sum(1 for result in results if result == "correct")
+        accuracy = num_correct_results / len(results)
+        return accuracy
+
+    @property
+    def metric_name(self):
+        return "accuracy"
+
+    @property
+    def base_prompt(self):
+        # from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning
+        return """In this task, you are given sentences from movie reviews. The task is to classify a sentence as 'positive' if the sentiment of the sentence is positive or as 'negative' if the sentiment of the sentence is negative. Return label only without any other text."""
+
+
+class SST2(SentimentAnalysis):
+    shorthand = "sa"
+
+    def load_validation_set(
+        self, validation_dataset: str | None, validation_split: str | None
+    ):
+        return super().load_validation_set("SetFit/sst2", "validation[:200]")
+
+    def load_test_set(self, test_dataset: str, test_split: str | None):
+        return super().load_test_set("SetFit/sst2", "test")
diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
new file mode 100644
index 0000000..40c3bd0
--- /dev/null
+++ b/evoprompt/task/task.py
@@ -0,0 +1,245 @@
+import logging
+import re
+from abc import ABCMeta, abstractmethod
+from argparse import Namespace
+from functools import lru_cache
+from statistics import mean
+from typing import Union
+
+from datasets import Dataset, load_dataset
+from llama_cpp import LlamaGrammar, deque
+from tqdm import tqdm
+
+from evoprompt.models import LLMModel
+from evoprompt.opt_types import ModelUsage
+from evoprompt.utils import log_calls
+
+logger = logging.getLogger(__name__)
+
+
+SYSTEM_MESSAGE = """
+You are given an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+"""
+
+DatasetDatum = dict
+
+
+class EarlyStoppingMonitor:
+
+    @abstractmethod
+    def update(self, score: float) -> bool:
+        raise NotImplementedError
+
+
+class MomentBasedStopping(EarlyStoppingMonitor):
+    """
+    Watch the first derivative (moment) of the metric to determine when to stop.
+ """ + + def __init__( + self, + *, + patience: int = 10, + start_after: int = 20, + min_moment_magnitude: float = 0.001, + ): + self.patience = patience + self.start_after = start_after + self.min_moment_magnitude = min_moment_magnitude + + self.moment_magnitudes = deque(maxlen=patience) + self.last_score = 0.0 + self.num_calls = 0 + + def update(self, score: float) -> bool: + # caclulate the current moment (dx/dt) + self.num_calls += 1 + if self.num_calls < self.start_after: + return False + + self.moment_magnitudes.append(abs(score - self.last_score)) + self.last_score = score + if len(self.moment_magnitudes) < self.patience: + return False + + if mean(self.moment_magnitudes) < self.min_moment_magnitude: + return True + + return False + + +class ParentBaselineBasedStopping(EarlyStoppingMonitor): + + def __init__( + self, + parent_histories: list[list[float]], + *, + patience: int = 10, + start_after: int = 20, + min_improvement: float = 0.001, + ): + self.parent_histories = parent_histories + self.patience = patience + self.start_after = start_after + self.min_improvement = min_improvement + self.num_calls = 0 + self.improvement_memory = deque(maxlen=patience) + + def update(self, score: float) -> bool: + self.num_calls += 1 + if self.num_calls < self.start_after: + return False + + parent_values = [ # get the metric value of the parents at the current step + ( + parent_history[self.num_calls - 1] + if len(parent_history) >= self.num_calls + else parent_history[-1] # extend with last value + ) + for parent_history in self.parent_histories + ] + self.improvement_memory.append( + score - max(parent_values) # compare with the best parent + ) + + if len(self.improvement_memory) < self.patience: + return False + + if max(self.improvement_memory) < self.min_improvement: + # if the highest improvement is less than the minimum improvement, we stop + return True + + return False + + +class Task(metaclass=ABCMeta): + shorthand: str + validation_dataset: Dataset + test_dataset: Dataset + + def __init__( + self, + model: Union[LLMModel], + validation_dataset: str | None = None, + test_dataset: str | None = None, + *, + use_grammar: bool, + no_early_stopping: bool = False, + evaluate_shortest_first: bool = False, + validation_split: str | None = None, + test_split: str | None = None, + debug: bool = False, + **kwargs, + ) -> None: + self.model = model + self.debug = debug + # whether we use the grammar to constrain the model output or not + self.use_grammar = use_grammar + self.no_early_stopping = no_early_stopping + self.evaluate_shortest_first = evaluate_shortest_first + if evaluate_shortest_first and no_early_stopping: + logger.warning( + "Both 'evaluate_shortest_first' and 'no_early_stopping' are set. 
+            )
+        self.validation_dataset = self.load_validation_set(
+            validation_dataset, validation_split
+        )
+        self.test_dataset = self.load_test_set(test_dataset, test_split)
+
+    def load_validation_set(
+        self, validation_dataset: str, validation_split: str | None
+    ):
+        dataset = load_dataset(validation_dataset, split=validation_split)
+        if self.debug and len(dataset) > 5:
+            dataset = dataset.select(range(5))
+        return dataset
+
+    def load_test_set(self, test_dataset: str, test_split: str | None):
+        dataset = load_dataset(test_dataset, split=test_split)
+        if self.debug and len(dataset) > 20:
+            dataset = dataset.select(range(20))
+        return dataset
+
+    @abstractmethod
+    def predict(self, prompt: str, *args, **kwargs) -> tuple[str, ModelUsage]:
+        pass
+
+    @abstractmethod
+    def _evaluate_sample(
+        self, prompt: str, datum: DatasetDatum
+    ) -> tuple[str, ModelUsage]:
+        pass
+
+    @abstractmethod
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        pass
+
+    @abstractmethod
+    def _aggregate_result(self, results: list) -> float:
+        pass
+
+    def evaluate(
+        self,
+        prompt: str,
+        dataset: Dataset,
+        parent_histories: list[list[float]] | None = None,
+    ) -> tuple[float, ModelUsage, list[float]]:
+
+        early_stopping: EarlyStoppingMonitor
+        early_stopping_params = {
+            "patience": max(len(dataset) // 20, 5),
+            "start_after": max(len(dataset) // 5, 5),
+        }
+        if parent_histories is not None:
+            early_stopping = ParentBaselineBasedStopping(
+                parent_histories, **early_stopping_params
+            )
+        else:
+            early_stopping = MomentBasedStopping(**early_stopping_params)
+
+        results: list = []
+        if self.evaluate_shortest_first:
+            dataset = sorted(dataset, key=lambda x: len(self._get_text_for_datum(x)))
+        dataset_iterator: tqdm[DatasetDatum] = tqdm(
+            dataset, desc="evaluating prompt", leave=False
+        )
+        evaluation_usage = ModelUsage()
+        evaluation_history = []
+
+        for datum in dataset_iterator:
+            result, usage = self._evaluate_sample(prompt, datum)
+            results.append(result)
+            current_metric = self._aggregate_result(results)
+            dataset_iterator.set_postfix(
+                {self.metric_name: f"{current_metric*100:.1f}%"}
+            )
+            evaluation_usage += usage
+            evaluation_history.append(current_metric)
+            if not self.no_early_stopping and early_stopping.update(current_metric):
+                logger.info(
+                    f"Early stopping after {len(results)} samples with {self.metric_name} of {current_metric*100:.1f}%"
+                )
+                break
+            # input(f'F1 score: {current_metric:.2f}')
+
+        return self._aggregate_result(results), evaluation_usage, evaluation_history
+
+    @log_calls("Evaluating validation dataset")
+    def evaluate_validation(
+        self, prompt: str, parent_histories: list[list[float]] | None = None
+    ):
+        return self.evaluate(prompt, self.validation_dataset, parent_histories)
+
+    @log_calls("Evaluating test dataset")
+    def evaluate_test(self, prompt: str):
+        return self.evaluate(prompt, self.test_dataset)
+
+    @property
+    @abstractmethod
+    def metric_name(self) -> str:
+        pass
+
+    @property
+    @abstractmethod
+    def base_prompt(self) -> str:
+        pass
diff --git a/evoprompt/utils.py b/evoprompt/utils.py
index ac8e6fe..3eefa2e 100644
--- a/evoprompt/utils.py
+++ b/evoprompt/utils.py
@@ -140,3 +140,8 @@ class log_calls:
                 arguments[argument_name] = value
 
         return arguments
+
+def get_all_subclasses(cls):
+    return set(cls.__subclasses__()).union(
+        [s for c in cls.__subclasses__() for s in get_all_subclasses(c)]
+    )
\ No newline at end of file
diff --git a/main.py b/main.py
index 6382480..6a3d6b5 100644
--- a/main.py
+++ b/main.py
@@ -72,7 +72,7 @@ if __name__ == "__main__":
"__main__": case "openai": evaluation_model = Llama(options) - task = get_task(options.task, evaluation_model, options) + task = get_task(options.task, evaluation_model, **options.__dict__) logger.info( f"Running with task {task.__class__.__name__} on dataset {task.validation_dataset.info.dataset_name}" ) -- GitLab