diff --git a/evoprompt/models.py b/evoprompt/models.py
index 21a3e15959a8f798afa0b02bc2144190b3e08804..52d6417644b4207ac34de106cd4eafcd4e197aa9 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -63,9 +63,6 @@ class LLMModel(ABC):
         prompt: str,
         *,
         use_cache: bool = False,
-        prompt_appendix: str = "",
-        prompt_prefix: str = "",
-        prompt_suffix: str = "",
         stop: str = None,
         history: ChatMessages | None = None,
         **kwargs: Any,
     ) -> tuple[str, ModelUsage]:
@@ -182,15 +179,10 @@ class Llama(LLMModel):
         prompt: str,
         *,
         use_cache: bool = False,
-        prompt_appendix: str = "",
-        prompt_prefix: str = "",
-        prompt_suffix: str = "",
         stop: str = None,
         history: ChatMessages | None = None,
         **kwargs: Any,
     ) -> tuple[str, ModelUsage]:
-        # create prompt
-        prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
         messages = [self._get_user_message(prompt)]
         if system_message is not None:
             prompt = system_message + prompt
@@ -268,9 +260,6 @@ class ChatModel:
         prompt: str,
         *,
         use_cache: bool = False,
-        prompt_appendix: str = "",
-        prompt_prefix: str = "",
-        prompt_suffix: str = "",
         stop: str = None,
         history: ChatMessages | None = None,
         **kwargs: Any,
     ) -> tuple[str, ModelUsage]:
@@ -283,9 +272,6 @@ class ChatModel:
             messages = history
         else:
             messages = []
-
-        # create prompt
-        prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
         messages += [self._get_user_message(prompt)]
 
         reponse, usage = self._create_completion(
diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index c43330e5957ec33e391240dd1d0acec9dd3bf4f5..98eed474c2036e5adacfc10e6b9a1c1fd2d0ee7b 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -24,9 +24,10 @@ def paraphrase_prompts(
     prompt: str,
     n: int,
     unique_paraphrases: bool = False,
+    consider_for_unique_paraphrases: list[str] = [],
     max_tries: int = 10,
     return_only_unique_paraphrases: bool = False,
-) -> tuple[list, ModelUsage]:
+) -> tuple[list[str], ModelUsage]:
     total_usage = ModelUsage()
     paraphrases = []
     num_tries = 0
@@ -36,16 +37,14 @@ def paraphrase_prompts(
         num_tries += 1
         paraphrase, _, usage = model.create_completion(
             system_message=PARAPHRASE_PROMPT,
-            prompt=prompt,
-            prompt_prefix=' Instruction: "',
-            prompt_suffix='"',
+            prompt=f"Instruction: {prompt}",
         )
         total_usage += usage
         if "<prompt>" in paraphrase:
             paraphrase = paraphrase.split("<prompt>")[1].split("</prompt>")[0]
         if (
             not unique_paraphrases
-            or paraphrase not in paraphrases
+            or paraphrase not in (paraphrases + consider_for_unique_paraphrases)
             or max_tries - num_tries == n - len(paraphrases)
         ):
             # add paraphrase only if unique_paraphrases==True and (if not already present or if the attempts run out)
@@ -124,6 +123,7 @@ class PromptOptimization:
                 top_prompts[promptindex_to_paraphrase],
                 n=1,
                 unique_paraphrases=True,
+                consider_for_unique_paraphrases=initial_population,
             )
             self.total_evolution_usage += paraphrase_usage
             initial_population += paraphrases
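Note: the two files above go together. `create_completion` loses its `prompt_prefix`/`prompt_suffix`/`prompt_appendix` parameters, so callers such as `paraphrase_prompts` now assemble the full prompt themselves, and the paraphrase uniqueness check additionally consults a caller-supplied list. A minimal standalone sketch of the extended de-duplication behaviour (the candidate strings are invented for illustration):

```python
# Sketch of the new uniqueness check, with hypothetical values.
paraphrases = ["Classify the sentiment of the text."]
consider_for_unique_paraphrases = ["Label the text's sentiment."]

candidate = "Label the text's sentiment."

# Previously only `paraphrases` was checked; a candidate duplicating a
# prompt from the initial population is now rejected as well.
if candidate not in (paraphrases + consider_for_unique_paraphrases):
    paraphrases.append(candidate)

print(paraphrases)  # unchanged: the duplicate candidate was skipped
```

One detail a reviewer might flag: `consider_for_unique_paraphrases: list[str] = []` is a mutable default argument. The function only reads from it, so it is harmless here, but `None` plus an in-function fallback is the more defensive idiom.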
diff --git a/evoprompt/task/question_answering.py b/evoprompt/task/question_answering.py
index fbd2d60867e89ee95bb532a7ee2a8d3283de8097..9cfa194a11fe05d8fdd4dc8dd0d9a8ed73c6622d 100644
--- a/evoprompt/task/question_answering.py
+++ b/evoprompt/task/question_answering.py
@@ -1,7 +1,7 @@
 import logging
 import re
 from abc import abstractmethod
-from functools import cache, lru_cache
+from functools import lru_cache
 from typing import Iterable
 
 from datasets import Dataset
@@ -55,10 +55,9 @@ class QuestionAnswering(Task):
 
         self.metric = load_metric("squad")
 
-    @cache
     def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar:
         # context-sensitive grammar
-        context = self._get_context_from_datum(datum)
+        context = self._get_context_for_datum(datum)
         try:
             return extractive_qa_grammar_fn(context)
         except Exception as e:
@@ -69,12 +68,14 @@ class QuestionAnswering(Task):
                 exc_info=e,
             )
 
+    @staticmethod
+    def _get_generation_prefix():
+        return "Answer: "
+
     def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
-        context = self._get_context_from_datum(datum)
-        question = self._get_question_from_datum(datum)
-        return (
-            "\nContext: " + '"' + context + '"' + "\nQuestion: " + '"' + question + '"'
-        )
+        context = self._get_context_for_datum(datum)
+        question = self._get_question_for_datum(datum)
+        return f"Context: {context}\nQuestion: {question}"
 
     def _get_demonstration_sample_ids(
         self, dataset: Dataset, n_evaluation_demo: int
@@ -87,39 +88,41 @@ class QuestionAnswering(Task):
         pass
 
     @abstractmethod
-    def _get_context_from_datum(self, datum: DatasetDatum):
+    def _get_context_for_datum(self, datum: DatasetDatum):
         pass
 
     @abstractmethod
-    def _get_question_from_datum(self, datum: DatasetDatum):
+    def _get_question_for_datum(self, datum: DatasetDatum):
         pass
 
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+    def _parse_response(self, response: str) -> str:
+        if self.use_grammar:
+            return response
+        # if we do not use a grammar, we need to extract the answer from the response
+        # otherwise the answer is from the context as enforced by the grammar
+        prefix_to_match = self._get_generation_prefix().replace(" ", r"\s?")
+        matches = re.findall(
+            # regex that matches class labels after "Response: "
+            rf"(?:{prefix_to_match})?(.+)",
+            response.splitlines()[-1],
+            flags=re.IGNORECASE,
+        )
+        # look for an answer in the response, if not found, use whole response
+        if matches:
+            return matches[-1]
+        else:
+            return response
+
+    def _evaluate_sample(self, response: str, datum: DatasetDatum) -> float:
         _id = self._get_id_from_datum(datum)
         gold_answers = self._get_gold_label_for_datum(datum)
-        response, usage = self.predict(prompt, datum)
         response = response.lower()
 
-        if not self.use_grammar:
-            # if we do not use a grammar, we need to extract the answer from the response
-            # otherwise the answer is from the context as enforced by the grammar
-            matches = re.findall(
-                # regex that matches class labels after "Response: "
-                rf"(?:Response:\s?)?(.+)",
-                response.splitlines()[-1],
-                flags=re.IGNORECASE,
-            )
-            # look for an answer in the response, if not found, use whole response
-            if matches:
-                answer = matches[-1]
-            else:
-                answer = response
-
         result = self.metric.compute(
-            predictions=[{"prediction_text": answer, "id": _id}],
+            predictions=[{"prediction_text": response, "id": _id}],
             references=[{"answers": gold_answers, "id": _id}],
         )
-        return result["f1"] / 100, usage
+        return result["f1"]
 
     def _aggregate_result(self, results: list[float]) -> float:
         return sum(results) / len(results)
@@ -179,10 +182,10 @@ class SQuAD(QuestionAnswering):
     def load_test_set(self, test_dataset: str, test_split: str | None):
         return super().load_test_set("squad", "validation")
 
-    def _get_context_from_datum(self, datum: DatasetDatum):
+    def _get_context_for_datum(self, datum: DatasetDatum):
         return datum["context"]
 
-    def _get_question_from_datum(self, datum: DatasetDatum):
+    def _get_question_for_datum(self, datum: DatasetDatum):
         return datum["question"]
 
     def _get_id_from_datum(self, datum: DatasetDatum):
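Note: question answering now separates parsing from scoring. `_parse_response` extracts the answer with a prefix-tolerant regex derived from `_get_generation_prefix()` instead of the hard-coded "Response: ", and `_evaluate_sample` feeds the parsed answer straight into the squad metric (whose F1 is already on a 0-100 scale, hence the dropped `/ 100`). A standalone sketch of the extraction, with the "Answer: " prefix inlined and the example strings invented:

```python
import re

# Standalone version of the extraction logic in _parse_response above.
def parse_response(response: str) -> str:
    prefix_to_match = "Answer: ".replace(" ", r"\s?")  # -> r"Answer:\s?"
    matches = re.findall(
        rf"(?:{prefix_to_match})?(.+)",  # the prefix is optional
        response.splitlines()[-1],       # only the last line is searched
        flags=re.IGNORECASE,
    )
    # fall back to the whole response if nothing matched
    return matches[-1] if matches else response

print(parse_response("Answer: Paris"))           # Paris
print(parse_response("answer:Paris"))            # Paris
print(parse_response("Some reasoning.\nParis"))  # Paris
```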
diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py
index 6886ed18ce89032b4f19b56f99155abab7f222b0..e0126a0055837b3d4905839358d031fd6bf3814b 100644
--- a/evoprompt/task/sentiment_analysis.py
+++ b/evoprompt/task/sentiment_analysis.py
@@ -22,6 +22,10 @@ class SentimentAnalysis(TextClassification):
     def _get_label_mapping() -> Mapping:
         return {"negative": 0, "positive": 1}
 
+    @staticmethod
+    def _get_generation_prefix():
+        return "Sentiment: "
+
 
 class HfSST2(BasePromptsFromJsonMixin, SentimentAnalysis):
     shorthand = "sst2-hf"
diff --git a/evoprompt/task/simplification.py b/evoprompt/task/simplification.py
index 4d58aebf61d4d9d167a6d623839b674b1c0a95e3..69231a80fe7464405f7382a1a82e4d9eb6a360fd 100644
--- a/evoprompt/task/simplification.py
+++ b/evoprompt/task/simplification.py
@@ -23,6 +23,10 @@ class Simplification(TextGeneration):
             references=[gold_label],
         )["sari"]
 
+    @staticmethod
+    def _get_generation_prefix():
+        return "Simplification: "
+
     @property
     def metric_name(self):
         return "sari"
diff --git a/evoprompt/task/subjectivity_classification.py b/evoprompt/task/subjectivity_classification.py
index 5ded8dada78f412afa15f7362a7813b9f32fb487..5db5edccc5b8f2f2951628276c6a925d4d909411 100644
--- a/evoprompt/task/subjectivity_classification.py
+++ b/evoprompt/task/subjectivity_classification.py
@@ -37,6 +37,10 @@ class Subj(BasePromptsFromJsonMixin, TextClassification):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         return datum["label"]
 
+    @staticmethod
+    def _get_generation_prefix():
+        return "Subjectivity: "
+
     @staticmethod
     @cache
     def _get_label_mapping() -> Mapping:
diff --git a/evoprompt/task/summarization.py b/evoprompt/task/summarization.py
index 3d45aefdca87d03d30025276cc5a391b596dee74..3d4789a5c84c27300e25115606c0149124f23d54 100644
--- a/evoprompt/task/summarization.py
+++ b/evoprompt/task/summarization.py
@@ -17,9 +17,16 @@ class Summarization(TextGeneration):
 
     def compute_metric(self, datum: DatasetDatum, prediction: str) -> float:
         gold_label = self._get_gold_label_for_datum(datum)
-        return self.metric.compute(predictions=[prediction], references=[gold_label])[
-            "rougeL"
-        ]
+        return (
+            self.metric.compute(predictions=[prediction], references=[gold_label])[
+                "rougeL"
+            ]
+            * 100
+        )
+
+    @staticmethod
+    def _get_generation_prefix():
+        return "Summary: "
 
     @property
     def metric_name(self):
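Note: each task now declares a one-word generation prefix ("Sentiment: ", "Simplification: ", "Subjectivity: ", "Summary: "), which the shared prompt builder in task.py below uses both for demonstrations and for the final generation cue. Together with the `* 100` on ROUGE-L here, the dropped `/ 100` on SQuAD F1 above, and the `* 100` on accuracy in text_classification.py further down, all metrics appear to be standardized to a 0-100 scale. A toy rendering of a demonstration block under the new scheme (the sample text and label are made up):

```python
# Toy rendering of a sentiment demonstration; the prefixes come from
# the task classes, the datum is hypothetical.
input_prefix = "Text: "            # TextClassification._get_input_prefix()
generation_prefix = "Sentiment: "  # SentimentAnalysis._get_generation_prefix()
text = "A gorgeous, witty, seductive movie."
gold = "positive"

print(f"{input_prefix}{text}\n{generation_prefix}{gold}")
# Text: A gorgeous, witty, seductive movie.
# Sentiment: positive
```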
diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index 860a620e7be456e4e611001a9e39e265b3327d20..d93a66ea6976bfda6227dcad981360e85bb34668 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -3,7 +3,7 @@ from abc import ABCMeta, abstractmethod
 from collections import deque
 from dataclasses import KW_ONLY, dataclass
 from statistics import mean
-from typing import Iterable, Literal
+from typing import Any, Iterable, Literal
 
 from datasets import Dataset, load_dataset
 from llama_cpp import LlamaGrammar
@@ -344,10 +344,7 @@ class Task(metaclass=ABCMeta):
         # TODO grammar also depends on prompt and vice-versa -> what are good labels?
         response, _, usage = self.model.create_completion(
             system_message=SYSTEM_MESSAGE,
-            # TODO Allow to modify prompt construction in subclasses
             prompt=prompt,
-            prompt_suffix="\n" + self._get_prompt_text_for_datum(datum),
-            prompt_appendix="\nResponse: ",
             # grammar can be applied to constrain the model output
             grammar=self._get_grammar(datum) if self.use_grammar else None,
             # we use cached completions to speed up the process although we loose the non-deterministic behavior of LMs, but we're ok with a single result
@@ -360,19 +357,25 @@ class Task(metaclass=ABCMeta):
 
         return response, usage
 
+    def _build_prompt(self, prompt: str, datum: DatasetDatum) -> str:
+        prompt = f"{prompt}\n\n{self._get_prompt_text_for_datum(datum)}\n{self._get_generation_prefix()}"
+        return prompt
+
+    @abstractmethod
+    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str: ...
+
+    @abstractmethod
+    def _get_generation_prefix() -> str: ...
+
     @abstractmethod
     def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar:
         pass
 
     @abstractmethod
-    def _evaluate_sample(
-        self, prompt: str, datum: DatasetDatum
-    ) -> tuple[str, ModelUsage]:
-        pass
+    def _evaluate_sample(self, response: str, datum: DatasetDatum) -> Any: ...
 
     @abstractmethod
-    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
-        pass
+    def _parse_response(self, response: str) -> str: ...
 
     @abstractmethod
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
@@ -404,13 +407,20 @@ class Task(metaclass=ABCMeta):
         # augment prompt with demonstration samples
         prompt += "".join(
             [
-                f"\n{self._get_prompt_text_for_datum(datum)}\nResponse: {self._get_gold_label_generation_for_datum(datum)}"
+                f"\n\n{self._get_prompt_text_for_datum(datum)}\n{self._get_generation_prefix()}{self._get_gold_label_generation_for_datum(datum)}"
                 for datum in self.demonstration_samples
             ]
        )
 
         for datum in dataset_iterator:
-            result, usage = self._evaluate_sample(prompt, datum)
+            # build prompt for current sample
+            prompt_for_datum = self._build_prompt(prompt, datum)
+            # run prediction
+            response, usage = self.predict(prompt=prompt_for_datum, datum=datum)
+            # parse response
+            response = self._parse_response(response=response)
+            # evaluate response
+            result = self._evaluate_sample(response=response, datum=datum)
             results.append(result)
             current_metric = self._aggregate_result(results)
             dataset_iterator.set_postfix({self.metric_name: f"{current_metric:.2f}"})
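Note: the monolithic per-sample `_evaluate_sample(prompt, datum)` contract is replaced by a four-step pipeline: build the prompt, predict, parse, score. A condensed sketch of the control flow (usage tracking and progress reporting omitted; `task` stands for any concrete `Task` subclass):

```python
# Condensed sketch of the new evaluation loop, not the literal code.
def evaluate_one(task, prompt: str, datum) -> float:
    # 1. instruction + task input + generation prefix, e.g.
    #    "<instruction>\n\nContext: ...\nQuestion: ...\nAnswer: "
    prompt_for_datum = task._build_prompt(prompt, datum)
    # 2. query the model (optionally grammar-constrained)
    response, _usage = task.predict(prompt=prompt_for_datum, datum=datum)
    # 3. task-specific extraction of the answer or label
    parsed = task._parse_response(response=response)
    # 4. task-specific scoring of the parsed response
    return task._evaluate_sample(response=parsed, datum=datum)
```

Worth noting: the abstract `_get_generation_prefix() -> str` is declared without `self` or a `@staticmethod` decorator. Since every concrete task defines it as a `@staticmethod`, `self._get_generation_prefix()` still resolves correctly at runtime, but decorating the abstract stub as well would make the contract explicit.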
diff --git a/evoprompt/task/text_classification.py b/evoprompt/task/text_classification.py
index ea7f3821dd62b1a1190a63a5e22ee7ed70388a82..1e19766de03ce76ec451939452d9ebcd06ce22b4 100644
--- a/evoprompt/task/text_classification.py
+++ b/evoprompt/task/text_classification.py
@@ -7,6 +7,7 @@ from typing import Any, Mapping
 from datasets import Dataset
 from llama_cpp import LlamaGrammar
 
+from evoprompt.opt_types import ModelUsage
 from evoprompt.task import Task
 from evoprompt.task.task import DatasetDatum
 
@@ -14,23 +15,20 @@ logger = logging.getLogger(__name__)
 
 
 class TextClassification(Task):
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
-        gold_label = self._get_gold_label_for_datum(datum)
+    def _parse_response(self, response: str):
         class_mapping = self._get_label_mapping()
-        response, usage = self.predict(prompt=prompt, datum=datum)
-        response = response.lower()
         if self.use_grammar:
             # model output is from label space
-            answer_label = class_mapping[response]
+            return response
         else:
             matches = re.findall(
-                # regex that matches class labels after "Response: "
-                rf"Response: ({'|'.join(class_mapping.keys())})",
+                # regex that matches class labels after the generation prefix
+                rf"{self._get_generation_prefix()}({'|'.join(class_mapping.keys())})",
                 response,
                 flags=re.IGNORECASE,
             )
             if matches:
-                answer_label = class_mapping[matches[-1]]
+                return matches[-1]
             else:
                 # look for a label in the response, if not found, return failed
                 matches = re.findall(
@@ -40,13 +38,22 @@ class TextClassification(Task):
                     flags=re.IGNORECASE,
                 )
                 if matches:
-                    answer_label = class_mapping[matches[-1]]
+                    return matches[-1]
                 else:
-                    logger.warning(f"Invalid answer: {response}")
-                    return "failed", usage
+                    return response
 
-        classification_result = "incorrect" if answer_label != gold_label else "correct"
-        return classification_result, usage
+    def _evaluate_sample(
+        self, response: str, datum: DatasetDatum
+    ) -> tuple[str, ModelUsage]:
+        gold_label = self._get_gold_label_for_datum(datum)
+        class_mapping = self._get_label_mapping()
+        response = response.lower()
+        if response not in class_mapping:
+            logger.warning(f"Invalid answer: {response}")
+            return "failed"
+        prediction = class_mapping[response]
+        classification_result = "incorrect" if prediction != gold_label else "correct"
+        return classification_result
 
     def _get_demonstration_sample_ids(
         self, dataset: Dataset, n_evaluation_demo: int
@@ -63,6 +70,13 @@ class TextClassification(Task):
             sample_ids += sample_ids_for_label
         return sample_ids
 
+    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
+        # NOTE it seems that quotes in the prompt make it worse
+        return f"{self._get_input_prefix()}{self._get_text_for_datum(datum)}"
+
+    @abstractmethod
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str: ...
+
     # NOTE cannot be cached since grammar is not picklable
     def _get_grammar(self, datum: DatasetDatum, verbose: bool = False):
         return LlamaGrammar.from_string(
@@ -72,18 +86,14 @@ class TextClassification(Task):
             verbose=verbose,
         )
 
-    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
-        # TODO do we need quotes?
-        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
+    @staticmethod
+    def _get_input_prefix():
+        return "Text: "
 
     @abstractmethod
     def _get_label_mapping(self) -> Mapping:
         pass
 
-    @abstractmethod
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
-        pass
-
     @abstractmethod
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         pass
@@ -98,7 +108,7 @@ class TextClassification(Task):
 
     def _aggregate_result(self, results: list[str]) -> float:
         num_correct_results = sum(1 for result in results if result == "correct")
-        accuracy = num_correct_results / len(results)
+        accuracy = num_correct_results / len(results) * 100
         return accuracy
 
     @property
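Note: classification follows the same split. `_parse_response` returns a label string (or the raw response when nothing matches) and `_evaluate_sample` maps it through the label mapping, yielding "correct", "incorrect", or "failed". A self-contained sketch of the resulting scoring semantics, with an invented label mapping and invented responses:

```python
# Standalone sketch of the classification scoring path (hypothetical data).
class_mapping = {"negative": 0, "positive": 1}

def evaluate_sample(parsed_response: str, gold_label: int) -> str:
    parsed_response = parsed_response.lower()
    if parsed_response not in class_mapping:
        return "failed"  # response could not be mapped to a label
    prediction = class_mapping[parsed_response]
    return "correct" if prediction == gold_label else "incorrect"

results = [
    evaluate_sample("Positive", 1),     # 'correct'
    evaluate_sample("negative", 1),     # 'incorrect'
    evaluate_sample("it was fine", 1),  # 'failed'
]
accuracy = sum(r == "correct" for r in results) / len(results) * 100
print(results, f"{accuracy:.1f}")  # ['correct', 'incorrect', 'failed'] 33.3
```

A small review observation: the new `_evaluate_sample` is still annotated `-> tuple[str, ModelUsage]` (which is what the added `ModelUsage` import serves), although it now returns a plain string.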
- return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"' + @staticmethod + def _get_input_prefix(): + return "Text: " @abstractmethod def _get_label_mapping(self) -> Mapping: pass - @abstractmethod - def _get_text_for_datum(self, datum: DatasetDatum) -> str: - pass - @abstractmethod def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str: pass @@ -98,7 +108,7 @@ class TextClassification(Task): def _aggregate_result(self, results: list[str]) -> float: num_correct_results = sum(1 for result in results if result == "correct") - accuracy = num_correct_results / len(results) + accuracy = num_correct_results / len(results) * 100 return accuracy @property diff --git a/evoprompt/task/text_generation.py b/evoprompt/task/text_generation.py index 7e20cd67f5f1ade917b5180cb952d1e379720048..81a81f6866ef1b4fdd9af732852645ee5a648b16 100644 --- a/evoprompt/task/text_generation.py +++ b/evoprompt/task/text_generation.py @@ -13,25 +13,39 @@ logger = logging.getLogger(__name__) class TextGeneration(Task): - def _evaluate_sample(self, prompt: str, datum: DatasetDatum): - response, usage = self.predict(prompt=prompt, datum=datum) - response = response.lower() - return self.compute_metric(datum, response), usage + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + if self.use_grammar: + raise ValueError("Text generation tasks do not support grammars") + + def _parse_response(self, response: str) -> str: + return response + + def _evaluate_sample(self, response: str, datum: DatasetDatum) -> float: + return self.compute_metric(datum, response.lower()) + + def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str: + # NOTE it seems that quotes in the prompt make it worse + return f"{self._get_input_prefix()}{self._get_text_for_datum(datum)}" + + @abstractmethod + def _get_text_for_datum(self, datum: DatasetDatum) -> str: ... def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar: # there is no grammar for open text generation return None + @staticmethod + def _get_input_prefix(): + return "Text: " + def _get_demonstration_sample_ids( self, dataset: Dataset, n_evaluation_demo: int ) -> Iterable[int]: # select demonstration samples uniformly at random return get_rng().choice(len(dataset), n_evaluation_demo, replace=False) - def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str: - # TODO do we need quotes? 
- return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"' - def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str: return self._get_gold_label_for_datum(datum) diff --git a/evoprompt/task/topic_classification.py b/evoprompt/task/topic_classification.py index 4103a43242c0f9adbf83adb75bf73cac16e19d34..e16aef39ca36b1aaaed81de4471ccf4c4c79dc6c 100644 --- a/evoprompt/task/topic_classification.py +++ b/evoprompt/task/topic_classification.py @@ -8,7 +8,13 @@ from evoprompt.task import TextClassification from evoprompt.task.task import DatasetDatum -class AGNews(BasePromptsFromJsonMixin, TextClassification): +class TopicClassification(TextClassification): + @staticmethod + def _get_generation_prefix(): + return "Topic: " + + +class AGNews(BasePromptsFromJsonMixin, TopicClassification): shorthand = "agn" base_prompts_file = "evoprompt/initial_prompts/agnews/prompts.json" @@ -35,7 +41,7 @@ class AGNews(BasePromptsFromJsonMixin, TextClassification): return dict(zip(classes, range(len(classes)))) -class TREC(BasePromptsFromJsonMixin, TextClassification): +class TREC(BasePromptsFromJsonMixin, TopicClassification): shorthand = "trec" base_prompts_file = "evoprompt/initial_prompts/trec/prompts.json"