From 46433787bfe91246673433e02b9f297bf02d0ea1 Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Tue, 24 Sep 2024 14:39:33 +0200
Subject: [PATCH] Add script for evaluating a prompt

---
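Notes: the new eval_prompt.py evaluates a single prompt on one of the registered
tasks, using the llamachat model options hard-coded in the script. A minimal
example invocation (the prompt text below is only an illustration, and the
quantized Llama model referenced in the script must be obtainable on the
machine; --prompt and --task are the two required arguments):

    python eval_prompt.py --task sst5 --prompt "What is the sentiment of the following text?"
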
 eval_prompt.py             | 53 ++++++++++++++++++++++++++++++++++++++
 evoprompt/models.py        |  2 ++
 evoprompt/task/__init__.py |  4 +--
 evoprompt/utils.py         |  7 +++--
 4 files changed, 62 insertions(+), 4 deletions(-)
 create mode 100644 eval_prompt.py

diff --git a/eval_prompt.py b/eval_prompt.py
new file mode 100644
index 0000000..a1cab74
--- /dev/null
+++ b/eval_prompt.py
@@ -0,0 +1,53 @@
+from argparse import Namespace
+import argparse
+import logging
+from evoprompt.models import LLMModel
+from evoprompt.task import get_task, tasks
+
+from evoprompt.utils import setup_console_logger
+
+logger = logging.getLogger(__name__)
+
+
+def evaluate_prompt(prompt: str, task_name: str, options: Namespace):
+    logger.info(f'Evaluating prompt "{prompt}"')
+
+    evaluation_model = LLMModel.get_model(name="llamachat", options=options)
+
+    task = get_task(
+        task_name,
+        evaluation_model,
+        **vars(options),
+    )
+
+    eval_score, eval_usage, _ = task.evaluate_validation(prompt)
+    logger.info(f"Score on evaluation set: {eval_score}")
+    test_score, test_usage, _ = task.evaluate_test(prompt)
+    logger.info(f"Score on test set: {test_score}")
+
+
+if __name__ == "__main__":
+    setup_console_logger()
+
+    argparser = argparse.ArgumentParser()
+    argparser.add_argument("-p", "--prompt", type=str, required=True)
+    argparser.add_argument(
+        "-t", "--task", type=str, choices=sorted(tasks.keys()), required=True
+    )
+    args = argparser.parse_args()
+
+    options = Namespace(
+        llama_path=None,
+        llama_model="QuantFactory/Meta-Llama-3.1-8B-Instruct-GGUF",
+        llama_model_file="Meta-Llama-3.1-8B-Instruct.Q8_0.gguf",
+        chat_format=None,
+        chat_handler=None,
+        verbose=False,
+        llama_verbose=False,
+        disable_cache=False,
+        use_grammar=False,
+        evaluation_strategy="simple",
+        max_tokens=None,
+    )
+
+    evaluate_prompt(args.prompt, args.task, options)
diff --git a/evoprompt/models.py b/evoprompt/models.py
index 8f717a1..d1ced01 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -153,6 +153,8 @@ class Llama(LLMModel):
         if seed is not None:
             add_kwargs["seed"] = seed
 
+        # TODO: some of these options could be made optional
+
         if options.llama_path is not None:
             # use local file
             self.model = llama_cpp.Llama(
diff --git a/evoprompt/task/__init__.py b/evoprompt/task/__init__.py
index aa750c7..5238a2c 100644
--- a/evoprompt/task/__init__.py
+++ b/evoprompt/task/__init__.py
@@ -22,7 +22,7 @@ tasks = {
 }
 
 
-def get_task(name: str, evaluation_model: LLMModel, **options):
+def get_task(name: str, evaluation_model: LLMModel, **options) -> Task:
     if name not in tasks:
         raise ValueError("Model %s does not exist", name)
     return tasks[name](evaluation_model, **options)
@@ -31,7 +31,7 @@ def get_task(name: str, evaluation_model: LLMModel, **options):
 argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
 argument_group = argument_parser.add_argument_group("Task arguments")
 argument_group.add_argument(
-    "--task",  type=str, required=True, choices=sorted(tasks.keys())
+    "--task", type=str, required=True, choices=sorted(tasks.keys())
 )
 argument_group.add_argument("--use-grammar", "-g", action="store_true")
 argument_group.add_argument(
diff --git a/evoprompt/utils.py b/evoprompt/utils.py
index 03c95a4..4ca878d 100644
--- a/evoprompt/utils.py
+++ b/evoprompt/utils.py
@@ -7,7 +7,7 @@ from functools import wraps
 from pathlib import Path
 from pprint import pformat
 from textwrap import dedent, indent
-from typing import Any, Callable
+from typing import Any, Callable, Type, TypeVar
 from uuid import uuid4
 
 import numpy
@@ -177,7 +177,10 @@ class log_calls:
         return arguments
 
 
-def get_all_subclasses(cls):
+T = TypeVar("T")
+
+
+def get_all_subclasses(cls: Type[T]) -> set[Type[T]]:
     return set(cls.__subclasses__()).union(
         [s for c in cls.__subclasses__() for s in get_all_subclasses(c)]
     )
-- 
GitLab