From 4fa8266f7e330c3c397920e4b83c8e3d31b509c9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Fri, 25 Oct 2024 18:11:48 +0200
Subject: [PATCH] allow setting evaluation engine via parameter

---
 evoprompt/models.py       |  8 +++++++-
 evoprompt/optimization.py |  8 +++++++-
 main.py                   | 15 ++++++++-------
 3 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 8360c9e..248bf94 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -564,6 +564,11 @@ class HfChat(ChatModel, LLMModel):
         else:
             model_call_kwargs["do_sample"] = False
 
+        if "max_tokens" in model_call_kwargs:
+            model_call_kwargs["max_completion_tokens"] = model_call_kwargs.pop(
+                "max_tokens"
+            )
+
         # input(
         #     f"The input for the model will look like this:\n'{self.pipeline.tokenizer.apply_chat_template(model_call_kwargs["text_inputs"], tokenize=False, add_generation_prompt=True)}'"
         # )
@@ -645,6 +650,7 @@ class OpenAIChat(ChatModel, LLMModel):
         self,
         messages: ChatMessages,
         *,
+        grammar: None = None,
         use_cache: bool,
         stop: str | None,
         max_tokens: int | None,
@@ -656,7 +662,7 @@ class OpenAIChat(ChatModel, LLMModel):
             "model": self.model_name,
             "messages": messages,
             "stop": stop,
-            "max_completion_tokens": max_tokens if max_tokens is not None else 1024,
+            "max_tokens": max_tokens if max_tokens is not None else 1024,
         }
         if use_randomness:
             if temperature is None:
diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index cc3710a..782de32 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -6,7 +6,6 @@ from difflib import Differ
 from pathlib import Path
 from typing import Any, Optional
 
-import wandb
 from rich.panel import Panel
 from rich.rule import Rule
 from tabulate import tabulate
@@ -16,6 +15,7 @@ from textual.containers import ScrollableContainer
 from textual.widgets import Collapsible, Footer, Label, Static, TextArea
 from tqdm import tqdm, trange
 
+import wandb
 from evoprompt.cli import argument_parser
 from evoprompt.models import ChatMessages, LLMModel
 from evoprompt.opt_types import (
@@ -464,6 +464,12 @@ def load_snapshot(path: Path):
 
 
 argument_group = argument_parser.add_argument_group("Optimization arguments")
+argument_group.add_argument(
+    "--evaluation-engine",
+    type=str,
+    choices=LLMModel.registered_models.keys(),
+    required=False,
+)
 argument_group.add_argument(
     "--evolution-engine",
     type=str,
diff --git a/main.py b/main.py
index 66a16ac..338b1eb 100644
--- a/main.py
+++ b/main.py
@@ -102,14 +102,15 @@ if __name__ == "__main__":
         logger.info(f"Using {options.judge_engine} as the judge engine")
 
     # set up evaluation model
-    # NOTE currenty we always stick to Llama (Llama or LlamaChat depending on evolution engine) as evaluation engine
-    # TODO allow to set separate engine and model for evaluation?
-    if isinstance(evolution_model, (Llama, LlamaChat, HfChat)):
-        evaluation_model_name = evolution_model.__class__.__name__.lower()
-    elif judge_model is not None and isinstance(judge_model, (Llama, LlamaChat)):
-        evaluation_model_name = judge_model.__class__.__name__.lower()
+    if options.evaluation_engine is not None:
+        evaluation_model_name = options.evaluation_engine
     else:
-        evaluation_model_name = "llamachat"
+        if isinstance(evolution_model, (Llama, LlamaChat, HfChat)):
+            evaluation_model_name = evolution_model.__class__.__name__.lower()
+        elif judge_model is not None and isinstance(judge_model, (Llama, LlamaChat)):
+            evaluation_model_name = judge_model.__class__.__name__.lower()
+        else:
+            evaluation_model_name = "llamachat"
 
     evaluation_model = LLMModel.get_model(name=evaluation_model_name, **vars(options))
     logger.info(f"Using {evaluation_model_name} as the evaluation engine")
--
GitLab
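Note for reviewers: a minimal, self-contained sketch of the resolution order this patch introduces in main.py. The resolve_evaluation_engine helper and the SimpleNamespace stand-in for parsed options are hypothetical, not part of the patch; the real code obtains options from argparse, checks isinstance against the concrete Llama/LlamaChat/HfChat classes rather than merely testing for None, and restricts valid --evaluation-engine values to LLMModel.registered_models.keys().

    # Hypothetical helper condensing the new selection order for illustration.
    from types import SimpleNamespace


    def resolve_evaluation_engine(options, evolution_model=None, judge_model=None):
        # An explicit --evaluation-engine always wins ...
        if options.evaluation_engine is not None:
            return options.evaluation_engine
        # ... otherwise fall back to the previous heuristic: reuse the
        # evolution engine, then the judge engine, then default to "llamachat".
        if evolution_model is not None:
            return type(evolution_model).__name__.lower()
        if judge_model is not None:
            return type(judge_model).__name__.lower()
        return "llamachat"


    # Stand-in for the argparse namespace built in main.py.
    options = SimpleNamespace(evaluation_engine="openaichat")
    assert resolve_evaluation_engine(options) == "openaichat"

    options = SimpleNamespace(evaluation_engine=None)
    assert resolve_evaluation_engine(options) == "llamachat"

The flag is optional (required=False), so existing invocations that omit it keep their previous behavior of inheriting the evolution or judge engine.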