diff --git a/evoprompt/evolution.py b/evoprompt/evolution.py
index 9026d4d9f02ec394fafc0fab5a7b89c54772e980..e96e0ee44734805a8dea0196b279c1f10ff8e149 100644
--- a/evoprompt/evolution.py
+++ b/evoprompt/evolution.py
@@ -10,7 +10,7 @@ from evoprompt.models import LLMModel
 from evoprompt.opt_types import ModelUsage, Prompt
 from evoprompt.optimization import PromptOptimization
 from evoprompt.task import Task
-from evoprompt.utils import log_calls
+from evoprompt.utils import log_calls, get_all_subclasses
 
 
 logger = logging.getLogger(__name__)
diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index 59b564a3109c6e8be8b914dcb94f3661cd5df205..7100a7a1f1630dadf3d24d2a3c6da03d7a9948e3 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -43,9 +43,19 @@ def paraphrase_prompts(
         total_usage += usage
         if "<prompt>" in paraphrase:
             paraphrase = paraphrase.split("<prompt>")[1].split("</prompt>")[0]
-        if not unique_paraphrases or paraphrase not in paraphrases:
+        if (
+            not unique_paraphrases
+            or paraphrase not in paraphrases
+            or max_tries - num_tries == n - len(paraphrases)
+        ):
             # add paraphrase only if not already present if unique_paraphrases==True
             paraphrases.append(paraphrase)
+
+    assert len(paraphrases) == n, "Requested %d paraphrases, but %d were generated." % (
+        n,
+        len(paraphrases),
+    )
+
     if return_only_unique_paraphrases:
         paraphrases = list(set(paraphrases))
     return paraphrases, total_usage
@@ -156,7 +166,11 @@ class PromptOptimization:
             unique_paraphrases=True,
         )
         self.total_evolution_usage += paraphrase_usage
-        logger.info("Paraphrased prompt '%s': %s.", self.task.base_prompt.replace("\r", "\\r").replace("\n", "\\n"), paraphrases)
+        logger.info(
+            "Paraphrased prompt '%s': %s.",
+            self.task.base_prompt.replace("\r", "\\r").replace("\n", "\\n"),
+            paraphrases,
+        )
 
         # the initial prompts
         initial_prompts = [self.task.base_prompt] + paraphrases
diff --git a/evoprompt/task/__init__.py b/evoprompt/task/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..64bbb9fdf4af632f4a00eea2b313619ad7bec444
--- /dev/null
+++ b/evoprompt/task/__init__.py
@@ -0,0 +1,37 @@
+from argparse import Namespace
+from typing import Literal
+
+from evoprompt.cli import argument_parser
+from evoprompt.models import LLMModel
+from evoprompt.task.question_answering import QuestionAnswering
+from evoprompt.task.sentiment_analysis import SentimentAnalysis
+
+# make sure to run definitions of subclasses of Task first
+from evoprompt.task.task import EvaluationStrategyKey, Task
+from evoprompt.utils import get_all_subclasses
+
+# at this point, we assume that all subclasses of Task have been defined
+tasks = {
+    task.shorthand: task
+    for task in get_all_subclasses(Task)
+    if hasattr(task, "shorthand")
+}
+
+
+def get_task(name: str, evaluation_model: LLMModel, **options):
+    if name not in tasks:
+        raise ValueError(f"Task {name} does not exist")
+    return tasks[name](evaluation_model, **options)
+
+
+argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
+argument_group = argument_parser.add_argument_group("Task arguments")
+argument_group.add_argument("--use-grammar", "-g", action="store_true")
+argument_group.add_argument(
+    "--evaluation-strategy",
+    choices=EvaluationStrategyKey.__args__,
+    default="simple",
+)
+argument_group.add_argument(
+    "--task", "-t", type=str, required=True, choices=tasks.keys()
+)
diff --git a/evoprompt/task/question_answering.py b/evoprompt/task/question_answering.py
new file mode 100644
index 0000000000000000000000000000000000000000..84a53e57f2c28b9c7f804f77f8ef9a99f37d5176
--- /dev/null
+++ b/evoprompt/task/question_answering.py
@@ -0,0 +1,187 @@
+import logging
+import re
+from abc import abstractmethod
+from functools import lru_cache
+
+from datasets import Dataset
+from evaluate import load as load_metric
+from llama_cpp import LlamaGrammar
+
+from evoprompt.task.task import SYSTEM_MESSAGE, DatasetDatum, Task
+
+logger = logging.getLogger(__name__)
+
+
+def grammar_continuous_with_arbitrary_end(sequence: list[str], quote: str = '"'):
+    if len(sequence) == 0:
+        return ""
+
+    if quote is None:
+        quote = ""
+
+    escape_item = lambda text: text.replace('"', '\\"')
+
+    grammar = quote + escape_item(sequence[0]) + quote
+    for idx, item in enumerate(sequence[1:]):
+        new_item = "(" + quote + escape_item(item) + quote + ")?"
+        if idx == 0:
+            grammar += new_item
+        else:
+            grammar = grammar[: -2 * idx] + new_item + grammar[-2 * idx :]
+    return grammar
+
+
+# a grammar in GBNF notation which takes the context into account
+# NOTE: for some reason the program crashes if the grammar is an instance attribute of the task class
+@lru_cache
+def extractive_qa_grammar_fn(context: str, verbose: bool = False):
+    grammar_str = "(\n"
+    tokens = re.split(r"(\W+)", context.strip())[:-1]
+
+    for i in range(len(tokens)):
+        if i > 0:
+            grammar_str += "\n|\n"
+        grammar_str += "(" + grammar_continuous_with_arbitrary_end(tokens[i:]) + ")"
+    grammar_str += "\n)"
+    return LlamaGrammar.from_string(f"root ::= {grammar_str}", verbose=verbose)
+
+
+class QuestionAnswering(Task):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.metric = load_metric("squad")
+
+    def predict(self, prompt: str, datum: DatasetDatum):
+        # run model for inference
+        grammar = None
+        if self.use_grammar:
+            # context-sensitive grammar
+            context = self._get_context_from_datum(datum)
+            try:
+                grammar = extractive_qa_grammar_fn(context)
+            except Exception as e:
+                logger.exception(
+                    "Could not create grammar for context (potentially maximum recursion depth exceeded), therefore we do not use a grammar for this sample.\nContext (with length %d): %s",
+                    len(context),
+                    context,
+                    exc_info=e,
+                )
+
+        response, usage = self.model(
+            system_message=SYSTEM_MESSAGE,
+            prompt=prompt,
+            prompt_appendix=self._get_prompt_text_for_datum(datum),
+            grammar=grammar,
+        )
+
+        if not self.use_grammar:
+            # we postprocess the model output to return as answer
+            response = response.strip()
+
+        return response, usage
+
+    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
+        context = self._get_context_from_datum(datum)
+        question = self._get_question_from_datum(datum)
+        return (
+            "\nContext: " + '"' + context + '"' + "\nQuestion: " + '"' + question + '"'
+        )
+
+    @abstractmethod
+    def _get_id_from_datum(self, datum: DatasetDatum):
+        pass
+
+    @abstractmethod
+    def _get_context_from_datum(self, datum: DatasetDatum):
+        pass
+
+    @abstractmethod
+    def _get_question_from_datum(self, datum: DatasetDatum):
+        pass
+
+    @abstractmethod
+    def _get_gold_labels_from_datum(self, datum: DatasetDatum):
+        pass
+
+    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        _id = self._get_id_from_datum(datum)
+        context = self._get_context_from_datum(datum)
+        question = self._get_question_from_datum(datum)
+        gold_answers = self._get_gold_labels_from_datum(datum)
+        answer, usage = self.predict(prompt, datum)
+        # input(f'*** PREDICTION ***\n\tContext: {datum["context"]}\n\tQuestion: {datum["question"]}\n\tAnswer: {answer}\n\tGold answer: {datum["answers"]["text"][0]}')
+        # TODO check if answer is lower-cased in metric computation
+
+        result = self.metric.compute(
+            predictions=[{"prediction_text": answer, "id": _id}],
+            references=[{"answers": gold_answers, "id": _id}],
+        )
+        return result["f1"] / 100, usage
+
+    def _aggregate_result(self, results: list[float]) -> float:
+        return sum(results) / len(results)
+
+    def evaluate(
+        self,
+        prompt: str,
+        dataset: Dataset,
+        parent_histories: list[list[float]] | None = None,
+    ):
+        def replace_symbol_for_grammar(sample: DatasetDatum):
+            symbol_replacement_mapping = {
+                "\u2013": "-",
+                "\u2014": "-",
+            }
+            symbol_replacement_mapping = dict(
+                (re.escape(k), v) for k, v in symbol_replacement_mapping.items()
+            )
+            symbol_replacement_pattern = re.compile(
+                "|".join(symbol_replacement_mapping.keys())
+            )
+            replace_fn = lambda text: symbol_replacement_pattern.sub(
+                lambda m: symbol_replacement_mapping[re.escape(m.group(0))], text
+            )
+            sample["context"] = replace_fn(sample["context"])
+            sample["answers"]["text"] = [
+                replace_fn(text) for text in sample["answers"]["text"]
+            ]
+            return sample
+
+        if self.use_grammar:
+            # NOTE: the LlamaGrammar has issues with the symbol '–', therefore we replace all occurrences with '-' (hyphen)
+            dataset = dataset.map(replace_symbol_for_grammar, desc="Replacing symbols")
+        return super().evaluate(prompt, dataset, parent_histories=parent_histories)
+
+    @property
+    def metric_name(self):
+        return "f1"
+
+    @property
+    def base_prompt(self):
+        # TODO find good prompt
+        return """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
+
+
+class SQuAD(QuestionAnswering):
+    shorthand = "squad"
+
+    def load_validation_set(
+        self, validation_dataset: str | None, validation_split: str | None
+    ):
+        return super().load_validation_set("squad", "train[:200]")
+
+    def load_test_set(self, test_dataset: str, test_split: str | None):
+        return super().load_test_set("squad", "validation")
+
+    def _get_context_from_datum(self, datum: DatasetDatum):
+        return datum["context"]
+
+    def _get_question_from_datum(self, datum: DatasetDatum):
+        return datum["question"]
+
+    def _get_id_from_datum(self, datum: DatasetDatum):
+        return datum["id"]
+
+    def _get_gold_labels_from_datum(self, datum: DatasetDatum):
+        return datum["answers"]
diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..628cbc57b3a08d5fa33b289ff892cad4674439c7
--- /dev/null
+++ b/evoprompt/task/sentiment_analysis.py
@@ -0,0 +1,106 @@
+import logging
+from abc import abstractmethod
+from argparse import Namespace
+from functools import lru_cache
+
+from llama_cpp import LlamaGrammar
+
+from evoprompt.task.task import SYSTEM_MESSAGE, DatasetDatum, Task
+
+logger = logging.getLogger(__name__)
+
+
+# a simple grammar in GBNF notation which allows either positive or negative
+# NOTE: for some reason the program crashes if the grammar is an instance attribute of the task class
+SA_OUTPUTS = set(("positive", "negative"))
+
+
+@lru_cache
+def sa_grammar_fn(verbose: bool = False):
+    return LlamaGrammar.from_string(
+        "root ::= ({})".format("|".join(('"' + cons + '"') for cons in SA_OUTPUTS)),
+        verbose=verbose,
+    )
+
+
+class SentimentAnalysis(Task):
+    def predict(self, prompt: str, datum: DatasetDatum):
+        # run model for inference using grammar to constrain output
+        # TODO grammar also depends on prompt and vice-versa -> what are good labels?
+        response, usage = self.model(
+            system_message=SYSTEM_MESSAGE,
+            prompt=prompt,
+            prompt_appendix=self._get_prompt_text_for_datum(datum),
+            grammar=sa_grammar_fn() if self.use_grammar else None,
+        )
+
+        if not self.use_grammar:
+            # we postprocess the model output to return as answer
+            response = response.strip()
+        # input(f"*** PREDICTION***\n\tText: {text}\n\tSentiment: {response}")
+
+        return response, usage
+
+    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        gold_label = self._get_gold_label_for_datum(datum)
+        sst2_labels = {"negative": 0, "positive": 1}
+        response, usage = self.predict(prompt=prompt, datum=datum)
+        response = response.lower()
+        if self.use_grammar:
+            # model output is from label space
+            answer_label = sst2_labels[response]
+        else:
+            answer_label = None
+            for label in sst2_labels.keys():
+                if label in response:
+                    answer_label = sst2_labels[label]
+                    break
+            else:
+                logger.warning(f"Invalid answer: {response}")
+                return "failed", usage
+
+        classification_result = "incorrect" if answer_label != gold_label else "correct"
+        return classification_result, usage
+
+    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
+        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
+
+    @abstractmethod
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        pass
+
+    @abstractmethod
+    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+        pass
+
+    def _aggregate_result(self, results: list[str]) -> float:
+        num_correct_results = sum(1 for result in results if result == "correct")
+        accuracy = num_correct_results / len(results)
+        return accuracy
+
+    @property
+    def metric_name(self):
+        return "accuracy"
+
+    @property
+    def base_prompt(self):
+        # from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning
+        return """In this task, you are given sentences from movie reviews. The task is to classify a sentence as 'positive' if the sentiment of the sentence is positive or as 'negative' if the sentiment of the sentence is negative. Return label only without any other text."""
+
+
+class SST2(SentimentAnalysis):
+    shorthand = "sst2"
+
+    def load_validation_set(
+        self, validation_dataset: str | None, validation_split: str | None
+    ):
+        return super().load_validation_set("stanfordnlp/sst2", "validation[:200]")
+
+    def load_test_set(self, test_dataset: str, test_split: str | None):
+        return super().load_test_set("stanfordnlp/sst2", "test")
+
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        return datum["sentence"]
+
+    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
+        return datum["label"]
diff --git a/evoprompt/task.py b/evoprompt/task/task.py
similarity index 50%
rename from evoprompt/task.py
rename to evoprompt/task/task.py
index 3743a7db7fd1a7ace72ad1507cac660df8ba7f03..9914c93a8d1d4e3c1600cffa81d4d9496b2e6e50 100644
--- a/evoprompt/task.py
+++ b/evoprompt/task/task.py
@@ -1,18 +1,13 @@
 import logging
-import re
-from abc import abstractmethod
-from argparse import Namespace
+from abc import ABCMeta, abstractmethod
+from collections import deque
 from dataclasses import KW_ONLY, dataclass
-from functools import lru_cache
 from statistics import mean
 from typing import Iterable, Literal, Union
 
 from datasets import Dataset, load_dataset
-from evaluate import load as load_metric
-from llama_cpp import LlamaGrammar, deque
 from tqdm import tqdm
 
-from evoprompt.cli import argument_parser
 from evoprompt.models import LLMModel
 from evoprompt.opt_types import ModelUsage
 from evoprompt.utils import log_calls
@@ -26,6 +21,7 @@ You are given an instruction that describes a task, paired with an input that pr
 
 DatasetDatum = dict
 
+
 ParentHistories = list[list[float]]
 
 
@@ -132,12 +128,12 @@ class ParentBaselineBasedStopping(EarlyStoppingMonitor):
         return False
 
 
-EvaluationStrategyKeys = Literal[
+EvaluationStrategyKey = Literal[
     "simple", "early-stopping", "shortest-first", "hardest-first"
 ]
 
 
-def get_evaluation_strategy(evaluation_strategy_key: EvaluationStrategyKeys):
+def get_evaluation_strategy(evaluation_strategy_key: EvaluationStrategyKey):
     match evaluation_strategy_key:
         case "simple":
             return SimpleStrategy
@@ -199,7 +195,7 @@ class ShortestFirstStrategy(EarlyStoppingStrategy):
         self, dataset: Dataset, parent_histories: ParentHistories | None
     ):
         sorted_dataset = sorted(
-            dataset, key=lambda x: len(self.task.get_text_for_datum(x))
+            dataset, key=lambda x: len(self.task._get_prompt_text_for_datum(x))
         )
         return super().get_dataset_iterator(sorted_dataset, parent_histories)
 
@@ -251,7 +247,7 @@ class HardestFirstStrategy(ShortestFirstStrategy):
         return self.early_stopping.update(score)
 
 
-class Task:
+class Task(metaclass=ABCMeta):
     shorthand: str
     validation_dataset: Dataset
     test_dataset: Dataset
@@ -259,23 +255,39 @@
     def __init__(
         self,
        model: Union[LLMModel],
-        validation_dataset: str,
-        test_dataset: str,
+        validation_dataset: str | None = None,
+        test_dataset: str | None = None,
         *,
         use_grammar: bool,
-        evaluation_strategy_cls: type[EvaluationStrategy],
+        evaluation_strategy: EvaluationStrategyKey,
         validation_split: str | None = None,
         test_split: str | None = None,
+        debug: bool = False,
+        **kwargs,
     ) -> None:
         self.model = model
-
+        self.debug = debug
         # whether we use the grammar to constrain the model output or not
         self.use_grammar = use_grammar
-        self.evaluation_strategy = evaluation_strategy_cls(self)
-        self.validation_dataset = load_dataset(
-            validation_dataset, split=validation_split
+        self.evaluation_strategy = get_evaluation_strategy(evaluation_strategy)(self)
+        self.validation_dataset = self.load_validation_set(
+            validation_dataset, validation_split
         )
-        self.test_dataset = load_dataset(test_dataset, split=test_split)
+        self.test_dataset = self.load_test_set(test_dataset, test_split)
+
+    def load_validation_set(
+        self, validation_dataset: str, validation_split: str | None
+    ):
+        dataset = load_dataset(validation_dataset, split=validation_split)
+        if self.debug and len(dataset) > 5:
+            dataset = dataset.select(range(5))
+        return dataset
+
+    def load_test_set(self, test_dataset: str, test_split: str | None):
+        dataset = load_dataset(test_dataset, split=test_split)
+        if self.debug and len(dataset) > 20:
+            dataset = dataset.select(range(20))
+        return dataset
 
     @abstractmethod
     def predict(self, prompt: str, *args, **kwargs) -> tuple[str, ModelUsage]:
@@ -288,7 +300,7 @@
         pass
 
     @abstractmethod
-    def get_text_for_datum(self, datum: DatasetDatum) -> str:
+    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
         pass
 
     @abstractmethod
@@ -313,6 +325,7 @@
         logger.info(
             f"using early stopping: {self.evaluation_strategy.early_stopping}",
         )
+
         results: list = []
         evaluation_usage = ModelUsage()
         evaluation_history = []
@@ -356,264 +369,3 @@
     @abstractmethod
     def base_prompt(self) -> str:
         pass
-
-
-# a simple grammar in GBNF notation which allows either positive or negative
-# NOTE: for some reason the program crashes if the grammar is an instance attribute of the task class
-SA_OUTPUTS = set(("positive", "negative"))
-
-
-@lru_cache
-def sa_grammar_fn(verbose: bool = False):
-    return LlamaGrammar.from_string(
-        "root ::= ({})".format("|".join(('"' + cons + '"') for cons in SA_OUTPUTS)),
-        verbose=verbose,
-    )
-
-
-class SentimentAnalysis(Task):
-    shorthand = "sa"
-
-    def __init__(self, model, options: Namespace):
-        super().__init__(
-            model,
-            validation_dataset="SetFit/sst2",
-            test_dataset="SetFit/sst2",
-            use_grammar=options.use_grammar,
-            evaluation_strategy_cls=get_evaluation_strategy(
-                options.evaluation_strategy
-            ),
-            validation_split=f"validation[:{5 if options.debug else 200}]",
-            test_split="test[:20]" if options.debug else "test",
-        )
-
-    def predict(self, prompt: str, text: str):
-        # run model for inference using grammar to constrain output
-        # TODO grammar also depends on prompt and vice-versa -> what are good labels?
-        response, usage = self.model(
-            system_message=SYSTEM_MESSAGE,
-            prompt=prompt,
-            prompt_appendix="\nInput: " + '"' + text + '"',
-            grammar=sa_grammar_fn() if self.use_grammar else None,
-        )
-
-        if not self.use_grammar:
-            # we postprocess the model output to return as answer
-            response = response.strip()
-        # input(f"*** PREDICTION***\n\tText: {text}\n\tSentiment: {response}")
-
-        return response, usage
-
-    def get_text_for_datum(self, datum: DatasetDatum) -> str:
-        return datum["text"]
-
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
-        sst2_labels = {"negative": 0, "positive": 1}
-
-        response, usage = self.predict(
-            prompt=prompt, text=self.get_text_for_datum(datum)
-        )
-        response = response.lower()
-        if self.use_grammar:
-            # model output is from label space
-            answer_label = sst2_labels[response]
-        else:
-            answer_label = None
-            for label in sst2_labels.keys():
-                if label in response:
-                    answer_label = sst2_labels[label]
-                    break
-            else:
-                logger.warning(f"Invalid answer: {response}")
-                return "failed", usage
-
-        classification_result = (
-            "incorrect" if answer_label != datum["label"] else "correct"
-        )
-        return classification_result, usage
-
-    def _aggregate_result(self, results: list[str]) -> float:
-        num_correct_results = sum(1 for result in results if result == "correct")
-        accuracy = num_correct_results / len(results)
-        return accuracy
-
-    @property
-    def metric_name(self):
-        return "accuracy"
-
-    @property
-    def base_prompt(self):
-        # from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning
-        return """In this task, you are given sentences from movie reviews. The task is to classify a sentence as 'positive' if the sentiment of the sentence is positive or as 'negative' if the sentiment of the sentence is negative. Return label only without any other text."""
-
-
-def grammar_continuous_with_arbitrary_end(sequence: list[str], quote: str = '"'):
-    if len(sequence) == 0:
-        return ""
-
-    if quote is None:
-        quote = ""
-
-    escape_item = lambda text: text.replace('"', '\\"')
-
-    grammar = quote + escape_item(sequence[0]) + quote
-    for idx, item in enumerate(sequence[1:]):
-        new_item = "(" + quote + escape_item(item) + quote + ")?"
-        if idx == 0:
-            grammar += new_item
-        else:
-            grammar = grammar[: -2 * idx] + new_item + grammar[-2 * idx :]
-    return grammar
-
-
-# a grammar in GBNF notation which takes the context into account
-# NOTE: for some reason the program crashes if the grammar is an instance attribute of the task class
-@lru_cache
-def extractive_qa_grammar_fn(context: str, verbose: bool = False):
-    grammar_str = "(\n"
-    tokens = re.split(r"(\W+)", context.strip())[:-1]
-
-    for i in range(len(tokens)):
-        if i > 0:
-            grammar_str += "\n|\n"
-        grammar_str += "(" + grammar_continuous_with_arbitrary_end(tokens[i:]) + ")"
-    grammar_str += "\n)"
-    return LlamaGrammar.from_string(f"root ::= {grammar_str}", verbose=verbose)
-
-
-class QuestionAnswering(Task):
-    shorthand = "qa"
-
-    def __init__(self, model, options: Namespace):
-
-        self.metric = load_metric("squad")
-
-        super().__init__(
-            model,
-            "squad",
-            "squad",
-            use_grammar=options.use_grammar,
-            evaluation_strategy_cls=get_evaluation_strategy(
-                options.evaluation_strategy
-            ),
-            validation_split=f"train[:{5 if options.debug else 200}]",
-            test_split="validation[:20]" if options.debug else "validation",
-        )
-
-    def predict(self, prompt: str, datum: DatasetDatum):
-        # run model for inference
-        grammar = None
-        if self.use_grammar:
-            # context-sensitive grammar
-            context = datum["context"]
-            try:
-                grammar = extractive_qa_grammar_fn(context)
-            except Exception as e:
-                logger.exception(
-                    "Could not create grammar for context (potentially maximum recursion depth exceeded), therefore we do not use a grammar for this sample.\nContext (with length %d): %s",
-                    len(context),
-                    context,
-                    exc_info=e,
-                )
-
-        response, usage = self.model(
-            system_message=SYSTEM_MESSAGE,
-            prompt=prompt,
-            prompt_appendix=self.get_text_for_datum(datum),
-            grammar=grammar,
-        )
-
-        if not self.use_grammar:
-            # we postprocess the model output to return as answer
-            response = response.strip()
-
-        return response, usage
-
-    def get_text_for_datum(self, datum: DatasetDatum) -> str:
-        return (
-            "\nContext: "
-            + '"'
-            + datum["context"]
-            + '"'
-            + "\nQuestion: "
-            + '"'
-            + datum["question"]
-            + '"'
-        )
-
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
-        answer, usage = self.predict(prompt, datum)
-        # input(f'*** PREDICTION ***\n\tContext: {datum["context"]}\n\tQuestion: {datum["question"]}\n\tAnswer: {answer}\n\tGold answer: {datum["answers"]["text"][0]}')
-        # TODO check if answer is lower-cased in metric computation
-
-        result = self.metric.compute(
-            predictions=[{"prediction_text": answer, "id": datum["id"]}],
-            references=[{"answers": datum["answers"], "id": datum["id"]}],
-        )
-
-        return result["f1"] / 100, usage
-
-    def _aggregate_result(self, results: list[float]) -> float:
-        return sum(results) / len(results)
-
-    def evaluate(
-        self,
-        prompt: str,
-        dataset: Dataset,
-        parent_histories: list[list[float]] | None = None,
-    ):
-        def replace_symbol_for_grammar(sample: DatasetDatum):
-            symbol_replacement_mapping = {
-                "\u2013": "-",
-                "\u2014": "-",
-            }
-            symbol_replacement_mapping = dict(
-                (re.escape(k), v) for k, v in symbol_replacement_mapping.items()
-            )
-            symbol_replacement_pattern = re.compile(
-                "|".join(symbol_replacement_mapping.keys())
-            )
-            replace_fn = lambda text: symbol_replacement_pattern.sub(
-                lambda m: symbol_replacement_mapping[re.escape(m.group(0))], text
-            )
-            sample["context"] = replace_fn(sample["context"])
-            sample["answers"]["text"] = [
-                replace_fn(text) for text in sample["answers"]["text"]
-            ]
-            return sample
-
-        if self.use_grammar:
-            # NOTE: the LlamaGrammar has issues with symbol '–' therefore we replace all occurences with '-' (hyphen)
-            dataset = dataset.map(replace_symbol_for_grammar, desc="Replacing symbols")
-        return super().evaluate(prompt, dataset, parent_histories=parent_histories)
-
-    @property
-    def metric_name(self):
-        return "f1"
-
-    @property
-    def base_prompt(self):
-        # TODO find good prompt
-        return """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
-
-
-tasks = {task.shorthand: task for task in Task.__subclasses__()}
-
-
-def get_task(name: str, evaluation_model: LLMModel, options: Namespace):
-    if name not in tasks:
-        raise ValueError("Model %s does not exist", name)
-    return tasks[name](evaluation_model, options)
-
-
-argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
-argument_group = argument_parser.add_argument_group("Task arguments")
-argument_group.add_argument("--use-grammar", "-g", action="store_true")
-argument_group.add_argument(
-    "--evaluation-strategy",
-    choices=EvaluationStrategyKeys.__args__,
-    default="simple",
-)
-argument_group.add_argument(
-    "--task", "-t", type=str, required=True, choices=tasks.keys()
-)
diff --git a/evoprompt/utils.py b/evoprompt/utils.py
index ac8e6fe46de885debb083d861afb9d42fb0c456a..3eefa2e9c709a75bff442f70880df3f1c0336347 100644
--- a/evoprompt/utils.py
+++ b/evoprompt/utils.py
@@ -140,3 +140,8 @@ class log_calls:
             arguments[argument_name] = value
 
         return arguments
+
+def get_all_subclasses(cls):
+    return set(cls.__subclasses__()).union(
+        [s for c in cls.__subclasses__() for s in get_all_subclasses(c)]
+    )
\ No newline at end of file
diff --git a/main.py b/main.py
index 6382480109a91f44a949b46f476b031e6da39906..6a3d6b51016bb33b0ea0b2375e0f4b23350f4a12 100644
--- a/main.py
+++ b/main.py
@@ -72,7 +72,7 @@ if __name__ == "__main__":
         case "openai":
             evaluation_model = Llama(options)
 
-    task = get_task(options.task, evaluation_model, options)
+    task = get_task(options.task, evaluation_model, **options.__dict__)
     logger.info(
         f"Running with task {task.__class__.__name__} on dataset {task.validation_dataset.info.dataset_name}"
    )
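
Usage note (illustrative sketch, not part of the patch): with this refactoring, a task is registered automatically as long as it subclasses Task and defines a shorthand, because the registry in evoprompt/task/__init__.py is built from get_all_subclasses(Task). The snippet below sketches how the new registry is expected to be exercised; my_model stands in for an already constructed LLMModel instance and is hypothetical.

from evoprompt.task import get_task, tasks

print(sorted(tasks))  # registered shorthands, e.g. ['squad', 'sst2']

# Resolve and instantiate a task by shorthand; keyword options are forwarded to
# Task.__init__ (use_grammar, evaluation_strategy, debug, ...). Any extra CLI
# options passed via **options.__dict__ in main.py are absorbed by **kwargs.
task = get_task(
    "sst2",
    my_model,  # hypothetical LLMModel instance
    use_grammar=False,
    evaluation_strategy="simple",
    debug=True,
)

# Evaluate the base prompt on the (debug-sized) validation split; the exact
# return shape is defined in Task.evaluate, which is not shown in full in this diff.
result = task.evaluate(task.base_prompt, task.validation_dataset)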