From 8355c2cbbe46fcf1799c6e9640581c685a4da718 Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Fri, 20 Sep 2024 12:13:22 +0200
Subject: [PATCH] Allow skipping bad evolutions based on the judge model

---
 evoprompt/models.py       | 12 --------
 evoprompt/optimization.py | 61 ++++++++++++++++++++++++++++++++-------
 2 files changed, 50 insertions(+), 23 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index de88e78..8f717a1 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -399,18 +399,6 @@ class OpenAiChat(ChatModel, LLMModel):
 
 
 argument_group = argument_parser.add_argument_group("Model arguments")
-argument_group.add_argument(
-    "--evolution-engine",
-    type=str,
-    choices=LLMModel.registered_models.keys(),
-    default="llama",
-)
-argument_group.add_argument(
-    "--judge-engine",
-    type=str,
-    default=None,
-    choices=LLMModel.registered_models.keys(),
-)
 argument_group.add_argument(
     "--disable-cache",
     action="store_true",
diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index 70b3139..8153870 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -1,20 +1,22 @@
-from collections import OrderedDict
 import json
 import logging
-from pathlib import Path
 import re
+from collections import OrderedDict
+from difflib import Differ
+from pathlib import Path
 from typing import Any, Literal, NamedTuple, Optional, TypedDict
 
-from textual.app import App, ComposeResult
-from textual.binding import Binding
-from textual.containers import ScrollableContainer
-from textual.widgets import Collapsible, Footer, Label, TextArea, Static
+import wandb
 from rich.panel import Panel
 from rich.rule import Rule
 from tabulate import tabulate
+from textual.app import App, ComposeResult
+from textual.binding import Binding
+from textual.containers import ScrollableContainer
+from textual.widgets import Collapsible, Footer, Label, Static, TextArea
 from tqdm import tqdm, trange
-import wandb
 
+from evoprompt.cli import argument_parser
 from evoprompt.models import ChatMessages, LLMModel
 from evoprompt.opt_types import ModelUsage, OptTypeEncoder, Prompt
 from evoprompt.task import Task
@@ -160,6 +162,8 @@ class PromptOptimization:
         self.judge_model = judge_model
         self.run_options = run_options
 
+        self.skip_bad_evolutions = run_options.get("skip_bad_evolutions", False)
+
     def evaluate_prompt(self, prompt: str, parents: tuple[Prompt] | None = None):
         parent_histories = (
             [parent.evaluation_history for parent in parents]
@@ -359,6 +363,7 @@ class PromptOptimization:
             for message in history
             if message["role"] in ["user", "assistant"]
         )
+        # TODO What if the history does not exist (is empty), i.e., for the first step in de-cot?
         prompt = (
             f"Context: {history_str}\nInstruction: {instruction}\nResponse: {response}"
         )
@@ -369,6 +374,7 @@ class PromptOptimization:
             "Wrap the answer with tags <judgement> and </judgement>. "
             "Please also add an explanation for your judgement."
         )
+        # input(f"System message:\n{system_message}\n\nPrompt:\n{prompt}\n")
         judgement_response, _, _, _ = self.judge_model.create_completion(
             system_message=system_message,
             prompt=prompt,
@@ -392,9 +398,13 @@ class PromptOptimization:
         if judge_happy:
             return Judgement(response, response, happy=True, skip=False)
 
-        logger.info(f"Prompt judged as bad, letting user take action.")
+        if self.skip_bad_evolutions:
+            # skip samples judged as bad
+            logger.info("Prompt judged as bad, skipping.")
+            return Judgement(response, response, happy=False, skip=True)
 
         # let user skip or correct the response in an interactive way
+        logger.info("Prompt judged as bad, letting user take action.")
         editor = ResponseEditor(
             instruction,
             response,
@@ -406,10 +416,19 @@ class PromptOptimization:
         if editor.skip:
             logger.info("User skipped prompt.")
         else:
+            delta = Differ().compare(
+                response.splitlines(), editor.modified_response.splitlines()
+            )
             logger.info(
-                "User corrected prompt:\n'%s'\n -> \n'%s'",
-                response,
-                editor.modified_response,
+                "User corrected prompt (delta):\n%s",
+                "\n".join(
+                    line
+                    for line in delta
+                    if line.startswith("+") or line.startswith("-")
+                ),
+                # "User corrected prompt:\n'%s'\n -> \n'%s'",
+                # response,
+                # editor.modified_response,
             )
 
         return Judgement(
@@ -429,3 +448,23 @@ def load_snapshot(path: Path):
         snapshot["T"],
         snapshot["N"],
     )
+
+
+argument_group = argument_parser.add_argument_group("Optimization arguments")
+argument_group.add_argument(
+    "--evolution-engine",
+    type=str,
+    choices=LLMModel.registered_models.keys(),
+    default="llama",
+)
+argument_group.add_argument(
+    "--judge-engine",
+    type=str,
+    default=None,
+    choices=LLMModel.registered_models.keys(),
+)
+argument_group.add_argument(
+    "--skip-bad-evolutions",
+    action="store_true",
+    help="Skip bad evolutions as judged by the judge engine (requires --judge-engine to be set)",
+)
-- 
GitLab
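
Usage note (a minimal sketch, not part of the patch): the new flag only takes
effect if the parsed CLI options reach PromptOptimization as a dict, so that
run_options.get("skip_bad_evolutions", False) sees it. The wiring below is an
assumption for illustration: importing evoprompt.optimization to trigger
argument registration, passing "llama" as the judge engine, and building
run_options via vars() are guesses, not confirmed by the patch.

    # Hypothetical wiring sketch; module layout as in the patch.
    import evoprompt.optimization  # noqa: F401  # assumed to register the "Optimization arguments"
    from evoprompt.cli import argument_parser

    options = argument_parser.parse_args(
        ["--evolution-engine", "llama", "--judge-engine", "llama", "--skip-bad-evolutions"]
    )
    # Assumed: run_options is the parsed namespace as a dict, matching the
    # run_options.get("skip_bad_evolutions", False) lookup in the patch.
    run_options = vars(options)
    assert run_options["skip_bad_evolutions"] is True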