From 4fa8266f7e330c3c397920e4b83c8e3d31b509c9 Mon Sep 17 00:00:00 2001
From: Daniel Grießhaber <griesshaber@hdm-stuttgart.de>
Date: Fri, 25 Oct 2024 18:11:48 +0200
Subject: [PATCH] allow setting the evaluation engine via a parameter

---
 evoprompt/models.py       |  8 +++++++-
 evoprompt/optimization.py |  8 +++++++-
 main.py                   | 15 ++++++++-------
 3 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 8360c9e..248bf94 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -564,6 +564,11 @@ class HfChat(ChatModel, LLMModel):
         else:
             model_call_kwargs["do_sample"] = False
 
+        if "max_tokens" in model_call_kwargs:
+            model_call_kwargs["max_completion_tokens"] = model_call_kwargs.pop(
+                "max_tokens"
+            )
+
         # input(
         #     f"The input for the model will look like this:\n'{self.pipeline.tokenizer.apply_chat_template(model_call_kwargs["text_inputs"], tokenize=False, add_generation_prompt=True)}'"
         # )
@@ -645,6 +650,7 @@ class OpenAIChat(ChatModel, LLMModel):
         self,
         messages: ChatMessages,
         *,
+        grammar: None = None,
         use_cache: bool,
         stop: str | None,
         max_tokens: int | None,
@@ -656,7 +662,7 @@ class OpenAIChat(ChatModel, LLMModel):
             "model": self.model_name,
             "messages": messages,
             "stop": stop,
-            "max_completion_tokens": max_tokens if max_tokens is not None else 1024,
+            "max_tokens": max_tokens if max_tokens is not None else 1024,
         }
         if use_randomness:
             if temperature is None:
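
Note on the models.py hunks above: HfChat now re-keys max_tokens to max_completion_tokens before handing the kwargs to its backend, OpenAIChat switches its request field from max_completion_tokens back to max_tokens, and OpenAIChat additionally accepts a keyword-only grammar argument it does not use, presumably so all engines expose the same call signature. The snippet below is only a hedged sketch of that pattern under those assumptions; build_call_kwargs and the backend_wants_completion_tokens flag are invented for illustration and are not part of the project.

from typing import Any


def build_call_kwargs(
    messages: list[dict[str, str]],
    *,
    grammar: None = None,  # accepted (and ignored) so all engines share one call signature
    stop: str | None = None,
    max_tokens: int | None = None,
    backend_wants_completion_tokens: bool = False,
) -> dict[str, Any]:
    kwargs: dict[str, Any] = {
        "messages": messages,
        "stop": stop,
        "max_tokens": max_tokens if max_tokens is not None else 1024,
    }
    if backend_wants_completion_tokens:
        # same re-keying as the new HfChat code above
        kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
    return kwargs


print(build_call_kwargs([{"role": "user", "content": "hi"}], max_tokens=64))
print(build_call_kwargs([{"role": "user", "content": "hi"}], grammar=None,
                        max_tokens=64, backend_wants_completion_tokens=True))
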
diff --git a/evoprompt/optimization.py b/evoprompt/optimization.py
index cc3710a..782de32 100644
--- a/evoprompt/optimization.py
+++ b/evoprompt/optimization.py
@@ -6,7 +6,6 @@ from difflib import Differ
 from pathlib import Path
 from typing import Any, Optional
 
-import wandb
 from rich.panel import Panel
 from rich.rule import Rule
 from tabulate import tabulate
@@ -16,6 +15,7 @@ from textual.containers import ScrollableContainer
 from textual.widgets import Collapsible, Footer, Label, Static, TextArea
 from tqdm import tqdm, trange
 
+import wandb
 from evoprompt.cli import argument_parser
 from evoprompt.models import ChatMessages, LLMModel
 from evoprompt.opt_types import (
@@ -464,6 +464,12 @@ def load_snapshot(path: Path):
 
 
 argument_group = argument_parser.add_argument_group("Optimization arguments")
+argument_group.add_argument(
+    "--evaluation-engine",
+    type=str,
+    choices=LLMModel.registered_models.keys(),
+    required=False,
+)
 argument_group.add_argument(
     "--evolution-engine",
     type=str,
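
The new --evaluation-engine option mirrors the existing --evolution-engine flag and draws its choices from LLMModel.registered_models.keys(), so any engine that registers itself with LLMModel automatically becomes a valid value. The registration mechanism itself is not part of this patch; the snippet below is only a rough sketch of how such a registry can feed argparse choices, with made-up *Sketch class names.

import argparse


class LLMModelSketch:
    """Illustrative registry base; the project's real registration may differ."""

    registered_models: dict[str, type] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # register every concrete engine under its lowercased class name
        LLMModelSketch.registered_models[cls.__name__.lower()] = cls


class LlamaChatSketch(LLMModelSketch): ...
class HfChatSketch(LLMModelSketch): ...
class OpenAIChatSketch(LLMModelSketch): ...


parser = argparse.ArgumentParser()
parser.add_argument(
    "--evaluation-engine",
    type=str,
    choices=LLMModelSketch.registered_models.keys(),  # mirrors the hunk above
    required=False,
)

args = parser.parse_args(["--evaluation-engine", "hfchatsketch"])
print(args.evaluation_engine)  # hfchatsketch
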
diff --git a/main.py b/main.py
index 66a16ac..338b1eb 100644
--- a/main.py
+++ b/main.py
@@ -102,14 +102,15 @@ if __name__ == "__main__":
         logger.info(f"Using {options.judge_engine} as the judge engine")
 
     # set up evaluation model
-    # NOTE currenty we always stick to Llama (Llama or LlamaChat depending on evolution engine) as evaluation engine
-    # TODO allow to set separate engine and model for evaluation?
-    if isinstance(evolution_model, (Llama, LlamaChat, HfChat)):
-        evaluation_model_name = evolution_model.__class__.__name__.lower()
-    elif judge_model is not None and isinstance(judge_model, (Llama, LlamaChat)):
-        evaluation_model_name = judge_model.__class__.__name__.lower()
+    if options.evaluation_engine is not None:
+        evaluation_model_name = options.evaluation_engine
     else:
-        evaluation_model_name = "llamachat"
+        if isinstance(evolution_model, (Llama, LlamaChat, HfChat)):
+            evaluation_model_name = evolution_model.__class__.__name__.lower()
+        elif judge_model is not None and isinstance(judge_model, (Llama, LlamaChat)):
+            evaluation_model_name = judge_model.__class__.__name__.lower()
+        else:
+            evaluation_model_name = "llamachat"
     evaluation_model = LLMModel.get_model(name=evaluation_model_name, **vars(options))
     logger.info(f"Using {evaluation_model_name} as the evaluation engine")
 
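
With this change an explicitly passed --evaluation-engine takes precedence, and only when it is absent does the previous heuristic apply (reuse the evolution engine if it is Llama/LlamaChat/HfChat, else a suitable judge engine, else default to llamachat). Below is a standalone sketch of that selection order; Options, the engine stubs, and pick_evaluation_engine are stand-ins, not the project's real objects (the real code then resolves the chosen name via LLMModel.get_model).

from dataclasses import dataclass


class Llama: ...
class LlamaChat: ...
class HfChat: ...
class OpenAIChat: ...


@dataclass
class Options:
    evaluation_engine: str | None = None


def pick_evaluation_engine(options: Options, evolution_model, judge_model) -> str:
    # 1) an explicit --evaluation-engine always wins
    if options.evaluation_engine is not None:
        return options.evaluation_engine
    # 2) otherwise reuse the evolution engine if it is a local model
    if isinstance(evolution_model, (Llama, LlamaChat, HfChat)):
        return type(evolution_model).__name__.lower()
    # 3) ... or the judge engine
    if judge_model is not None and isinstance(judge_model, (Llama, LlamaChat)):
        return type(judge_model).__name__.lower()
    # 4) ... else fall back to llamachat
    return "llamachat"


# engine names below come from the lowercased stand-in class names
print(pick_evaluation_engine(Options("openaichat"), OpenAIChat(), None))  # openaichat
print(pick_evaluation_engine(Options(None), HfChat(), None))              # hfchat
print(pick_evaluation_engine(Options(None), OpenAIChat(), None))          # llamachat
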
-- 
GitLab