diff --git a/evoprompt/evolution.py b/evoprompt/evolution.py
index 49050b594e62d899bf4e89aa5517fbdbf85b8e2f..ec6757988c54a58e68d6921dd96d952fe22e4ffb 100644
--- a/evoprompt/evolution.py
+++ b/evoprompt/evolution.py
@@ -1,9 +1,8 @@
 import logging
-from abc import ABCMeta, abstractmethod
 import re
+from abc import ABCMeta, abstractmethod
 from typing import Any
 
-from evoprompt.utils import get_rng
 from tqdm import trange
 
 from evoprompt.cli import argument_parser
@@ -12,7 +11,7 @@ from evoprompt.opt_types import ModelUsage, Prompt
 from evoprompt.optimization import PromptOptimization
 from evoprompt.task import Task
 from evoprompt.template_de import get_de_prompt_template
-from evoprompt.utils import get_all_subclasses, log_calls
+from evoprompt.utils import get_all_subclasses, get_rng, log_calls
 
 logger = logging.getLogger(__name__)
 
@@ -47,12 +46,14 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
         task: Task,
         evolution_model: LLMModel,
         evaluation_model: LLMModel,
+        judge_model: LLMModel,
         run_options: dict[str, Any] = {},
     ) -> None:
         super().__init__(
             task=task,
             evolution_model=evolution_model,
             evaluation_model=evaluation_model,
+            judge_model=judge_model,
             run_options=run_options,
         )
         self.use_evolution_demo = run_options.get("use_evolution_demo", False)
@@ -180,9 +181,10 @@ class GeneticAlgorithm(EvolutionAlgorithm):
         # Based on this two-step process, we design instructions, guiding LLMs to
         # generate a new prompt based on these steps to perform Evo(·) in Algorithm 1.
-        evolved_prompt, _, usage = self.evolution_model.create_completion(
+        filled_prompt = GA_PROMPT.format(prompt1=prompt_1, prompt2=prompt_2)
+        evolved_prompt, messages, usage = self.evolution_model.create_completion(
             system_message=SYSTEM_MESSAGE,
-            prompt=GA_PROMPT.format(prompt1=prompt_1, prompt2=prompt_2),
+            prompt=filled_prompt,
         )
 
         if "<prompt>" in evolved_prompt:
             evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
@@ -193,6 +195,7 @@ class GeneticAlgorithm(EvolutionAlgorithm):
             prompt_2,
             evolved_prompt,
         )
+        self.judge_step(filled_prompt, evolved_prompt, messages)
 
         return evolved_prompt, usage
 
@@ -240,14 +243,17 @@ class DifferentialEvolution(EvolutionAlgorithm):
             prompts_current_evolution, key=lambda prompt: prompt.score
         )
-        evolved_prompt, _, usage = self.evolution_model.create_completion(
+        filled_prompt = get_de_prompt_template(
+            self.use_evolution_demo, self.task
+        ).format(
+            prompt1=prompt_1,
+            prompt2=prompt_2,
+            prompt3=best_prompt_current_evolution,
+            basic_prompt=prompts_current_evolution[current_iteration],
+        )
+        evolved_prompt, messages, usage = self.evolution_model.create_completion(
             system_message=SYSTEM_MESSAGE,
-            prompt=get_de_prompt_template(self.use_evolution_demo, self.task).format(
-                prompt1=prompt_1,
-                prompt2=prompt_2,
-                prompt3=best_prompt_current_evolution,
-                basic_prompt=prompts_current_evolution[current_iteration],
-            ),
+            prompt=filled_prompt,
         )
 
         matches = re.findall(
             # regex that matches any characters between last pair of <prompt></prompt>, also if </prompt> is missing
@@ -275,6 +281,8 @@ class DifferentialEvolution(EvolutionAlgorithm):
             evolved_prompt,
         )
 
+        self.judge_step(filled_prompt, evolved_prompt, messages)
+
         return evolved_prompt, usage
 
     @log_calls("Performing update for DE")
@@ -314,15 +322,16 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
         )
 
         messages = None
-        for idx, prompt_template in enumerate(DE_COT_PROMPTS):
+        for idx, prompt in enumerate(DE_COT_PROMPTS):
+            filled_prompt = prompt.format(
+                prompt1=prompt_1,
+                prompt2=prompt_2,
+                prompt3=best_prompt_current_evolution,
+                basic_prompt=prompts_current_evolution[current_iteration],
+            )
             response, messages, usage = self.evolution_model.create_completion(
                 system_message=SYSTEM_MESSAGE,
-                prompt=prompt_template.format(
-                    prompt1=prompt_1,
-                    prompt2=prompt_2,
-                    prompt3=best_prompt_current_evolution,
-                    basic_prompt=prompts_current_evolution[current_iteration],
-                ),
+                prompt=filled_prompt,
                 history=messages,
                 stop="</prompt>" if idx == len(DE_COT_PROMPTS) - 1 else None,
             )
@@ -332,6 +341,7 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
                 messages,
                 response,
             )
+            self.judge_step(filled_prompt, response, history=messages)
             # input(messages)
             # input(response)
 
diff --git a/evoprompt/models.py b/evoprompt/models.py
index 52d6417644b4207ac34de106cd4eafcd4e197aa9..b2877e7e2241247820db2ede765353f7036f982d 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -66,7 +66,21 @@ class LLMModel(ABC):
         stop: str = None,
         history: ChatMessages | None = None,
         **kwargs: Any,
-    ) -> tuple[str, ModelUsage]: ...
+    ) -> tuple[str, list[dict[str, str]], ModelUsage]:
+        # create prompt
+        prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
+        messages = [self._get_user_message(prompt)]
+        model_input = self.build_model_input(prompt, system_message, messages, history)
+
+        response, usage = self._create_completion(
+            **model_input,
+            stop=stop,
+            use_cache=use_cache,
+            **kwargs,
+        )
+
+        messages.append(self._get_assistant_message(response))
+        return response, messages, usage
 
     def _get_user_message(self, content: str):
         return {
@@ -342,11 +356,16 @@ class OpenAiChat(ChatModel, LLMModel):
 argument_group = argument_parser.add_argument_group("Model arguments")
 argument_group.add_argument(
     "--evolution-engine",
-    "-e",
     type=str,
     choices=LLMModel.models.keys(),
     default="llama",
 )
+argument_group.add_argument(
+    "--judge-engine",
+    type=str,
+    default=None,
+    choices=LLMModel.models.keys(),
+)
 argument_group.add_argument(
     "--disable-cache",
     action="store_true",
diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index 98eed474c2036e5adacfc10e6b9a1c1fd2d0ee7b..f7b730fb7b8649a6a950e31c615e24249c68d0f2 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -78,11 +78,13 @@ class PromptOptimization:
         task: Task,
         evolution_model: LLMModel,
         evaluation_model: LLMModel,
+        judge_model: LLMModel,
         run_options: dict[str, Any] = {},
     ) -> None:
         self.task = task
         self.evolution_model = evolution_model
         self.evaluation_model = evaluation_model
+        self.judge_model = judge_model
         self.run_options = run_options
 
     def evaluate_prompt(self, prompt: str, parents: tuple[Prompt] | None = None):
@@ -245,6 +247,25 @@ class PromptOptimization:
             cls=OptTypeEncoder,
         )
 
+    def judge_step(
+        self, instruction: str, response: str, history: list[dict[str, str]]
+    ):
+        print("\n" * 20)
+        print(
+            "Instruction:", instruction, "\nResponse:", response, "\nHistory:", history
+        )
+
+        # TODO: judge the actual response
+        judge_happy = False
+
+        logger.info(
+            f"{self.judge_model.__class__.__name__} judged the response as {'good' if judge_happy else 'bad'}"
+        )
+        if judge_happy:
+            return True
+
+        logger.info("Prompt judged as bad. Letting user change the prompt.")
+
 
 def load_snapshot(path: Path):
     import json
diff --git a/evoprompt/task/__init__.py b/evoprompt/task/__init__.py
index c83a698b57a52a9f05c5ba7d2ba0dfd8dd43233a..e628c1a796d06a7d14507c3db696da0cc828a5e9 100644
--- a/evoprompt/task/__init__.py
+++ b/evoprompt/task/__init__.py
@@ -1,20 +1,16 @@
-from argparse import Namespace
-from typing import Literal
-
 from evoprompt.cli import argument_parser
 from evoprompt.models import LLMModel
+from evoprompt.task.question_answering import QuestionAnswering
+from evoprompt.task.sentiment_analysis import SentimentAnalysis
+from evoprompt.task.simplification import ASSET, Simplification
+from evoprompt.task.subjectivity_classification import Subj
+from evoprompt.task.summarization import SAMSum, Summarization
 
 # make sure to run definitions of subclasses of Task first
 from evoprompt.task.task import EvaluationStrategyKey, Task
-from evoprompt.task.question_answering import QuestionAnswering
 from evoprompt.task.text_classification import TextClassification
-from evoprompt.task.sentiment_analysis import SentimentAnalysis
-from evoprompt.task.topic_classification import AGNews, TREC
-from evoprompt.task.subjectivity_classification import Subj
 from evoprompt.task.text_generation import TextGeneration
-from evoprompt.task.summarization import Summarization, SAMSum
-from evoprompt.task.simplification import Simplification, ASSET
-
+from evoprompt.task.topic_classification import TREC, AGNews
 from evoprompt.utils import get_all_subclasses
 
 # at this point, we assume that all subclasses of Task have been defined
@@ -34,7 +30,7 @@ def get_task(name: str, evaluation_model: LLMModel, **options):
 argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
 argument_group = argument_parser.add_argument_group("Task arguments")
 argument_group.add_argument(
-    "--task", "-t", type=str, required=True, choices=sorted(tasks.keys())
+    "--task", type=str, required=True, choices=sorted(tasks.keys())
 )
 argument_group.add_argument("--use-grammar", "-g", action="store_true")
 argument_group.add_argument(
diff --git a/main.py b/main.py
index e0b8faedc0485ebce7a65ac35969faa971580a74..a6474b2408f0ab397e98a56bef5e9d4f1cf894f3 100644
--- a/main.py
+++ b/main.py
@@ -63,13 +63,20 @@ if __name__ == "__main__":
 
     # # set up evolution model
     evolution_model = LLMModel.get_model(options.evolution_engine, options=options)
-
     match options.evolution_engine:
         case "llama":
             logger.info("Using Llama as the evolution engine")
         case "openai":
             logger.info(f"Using {options.openai_model} as the evolution engine")
 
+    judge_model: LLMModel
+    if options.judge_engine is not None:
+        judge_model = LLMModel.get_model(options.judge_engine, options=options)
+        logger.info(f"Using {options.judge_engine} as the judge engine")
+    else:
+        judge_model = evolution_model
+        logger.info("Using the same model for judging as for evolution")
+
     # set up evaluation model
     # NOTE currenty we always stick to Llama as evaluation engine
     # TODO allow to set separate engine and model for evaluation?
@@ -94,6 +101,7 @@ if __name__ == "__main__":
         task=task,
         evolution_model=evolution_model,
         evaluation_model=evaluation_model,
+        judge_model=judge_model,
         run_options=options.__dict__,
     )
     optimizer.run(10, debug=debug)
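Note on the open TODO: the new judge_step in optimization.py still hard-codes judge_happy = False and only logs the verdict. A minimal sketch of how the verdict could be obtained from the judge model, assuming only the three-tuple create_completion signature introduced in models.py; the CompletionModel protocol, the judge_response helper, and the YES/NO prompt wording are illustrative and not part of this patch.

from typing import Any, Protocol


class CompletionModel(Protocol):
    """Anything exposing the (response, messages, usage) create_completion API from evoprompt.models.LLMModel."""

    def create_completion(
        self, system_message: str, prompt: str, **kwargs: Any
    ) -> tuple[str, list[dict[str, str]], Any]: ...


def judge_response(judge_model: CompletionModel, instruction: str, response: str) -> bool:
    """Ask the judge model for a YES/NO verdict on an evolved prompt (hypothetical helper)."""
    judge_prompt = (
        "Below is an instruction given to a prompt-evolution model and the model's response.\n\n"
        f"Instruction:\n{instruction}\n\n"
        f"Response:\n{response}\n\n"
        "Does the response follow the instruction and contain a usable prompt wrapped in "
        "<prompt></prompt> tags? Answer with YES or NO."
    )
    verdict, _messages, _usage = judge_model.create_completion(
        system_message="You are a strict but fair judge of LLM outputs.",
        prompt=judge_prompt,
    )
    # Treat anything that does not clearly start with "yes" as a rejection.
    return verdict.strip().lower().startswith("yes")

With this patch the judge can be a separate engine selected via --judge-engine; when the flag is omitted, main.py falls back to reusing the evolution model for judging.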