From 46433787bfe91246673433e02b9f297bf02d0ea1 Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Tue, 24 Sep 2024 14:39:33 +0200
Subject: [PATCH] Add script for evaluating prompt

---
 eval_prompt.py             | 53 ++++++++++++++++++++++++++++++++++++++
 evoprompt/models.py        |  2 ++
 evoprompt/task/__init__.py |  4 +--
 evoprompt/utils.py         |  7 +++--
 4 files changed, 62 insertions(+), 4 deletions(-)
 create mode 100644 eval_prompt.py

diff --git a/eval_prompt.py b/eval_prompt.py
new file mode 100644
index 0000000..a1cab74
--- /dev/null
+++ b/eval_prompt.py
@@ -0,0 +1,53 @@
+from argparse import Namespace
+import argparse
+import logging
+from evoprompt.models import LLMModel
+from evoprompt.task import get_task, tasks
+
+from evoprompt.utils import setup_console_logger
+
+logger = logging.getLogger(__name__)
+
+
+def evaluate_prompt(prompt: str, **kwargs):
+    logger.info(f'Evaluating prompt "{prompt}"')
+
+    evaluation_model = LLMModel.get_model(name="llamachat", options=options)
+
+    task = get_task(
+        "sst5",
+        evaluation_model,
+        **kwargs,
+    )
+
+    eval_score, eval_usage, _ = task.evaluate_validation(prompt)
+    logger.info(f"Score on evaluation set: {eval_score}")
+    test_score, test_usage, _ = task.evaluate_test(prompt)
+    logger.info(f"Score on test set: {test_score}")
+
+
+if __name__ == "__main__":
+    setup_console_logger()
+
+    argparser = argparse.ArgumentParser()
+    argparser.add_argument("-p", "--prompt", type=str, required=True)
+    argparser.add_argument(
+        "-t", "--task", type=str, choices=sorted(tasks.keys()), required=True
+    )
+    args = argparser.parse_args()
+
+    options = Namespace(
+        llama_path=None,
+        llama_model="QuantFactory/Meta-Llama-3.1-8B-Instruct-GGUF",
+        llama_model_file="Meta-Llama-3.1-8B-Instruct.Q8_0.gguf",
+        chat_format=None,
+        chat_handler=None,
+        verbose=False,
+        llama_verbose=False,
+        disable_cache=False,
+        use_grammar=False,
+        evaluation_strategy="simple",
+        max_tokens=None,
+    )
+
+    evaluate_prompt(**vars(args), **vars(options))
diff --git a/evoprompt/models.py b/evoprompt/models.py
index 8f717a1..d1ced01 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -153,6 +153,8 @@ class Llama(LLMModel):
         if seed is not None:
             add_kwargs["seed"] = seed
 
+        # TODO some options could be optional
+
         if options.llama_path is not None:
             # use local file
             self.model = llama_cpp.Llama(
diff --git a/evoprompt/task/__init__.py b/evoprompt/task/__init__.py
index aa750c7..5238a2c 100644
--- a/evoprompt/task/__init__.py
+++ b/evoprompt/task/__init__.py
@@ -22,7 +22,7 @@ tasks = {
 }
 
 
-def get_task(name: str, evaluation_model: LLMModel, **options):
+def get_task(name: str, evaluation_model: LLMModel, **options) -> Task:
     if name not in tasks:
         raise ValueError("Model %s does not exist", name)
     return tasks[name](evaluation_model, **options)
@@ -31,7 +31,7 @@ argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
 
 argument_group = argument_parser.add_argument_group("Task arguments")
 argument_group.add_argument(
-    "--task", type=str, required=True, choices=sorted(tasks.keys())
+    "--task", type=str, required=True, choices=sorted(tasks.keys())
 )
 argument_group.add_argument("--use-grammar", "-g", action="store_true")
 argument_group.add_argument(
diff --git a/evoprompt/utils.py b/evoprompt/utils.py
index 03c95a4..4ca878d 100644
--- a/evoprompt/utils.py
+++ b/evoprompt/utils.py
@@ -7,7 +7,7 @@ from functools import wraps
 from pathlib import Path
 from pprint import pformat
 from textwrap import dedent, indent
-from typing import Any, Callable
+from typing import Any, Callable, Type, TypeVar
 from uuid import uuid4
 
 import numpy
@@ -177,7 +177,10 @@ class log_calls:
         return arguments
 
 
-def get_all_subclasses(cls):
+T = TypeVar("T")
+
+
+def get_all_subclasses(cls: Type[T]) -> set[T]:
     return set(cls.__subclasses__()).union(
         [s for c in cls.__subclasses__() for s in get_all_subclasses(c)]
     )
-- 
GitLab
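
Usage sketch for the new script (assumes the repository root as working directory; the prompt text below is made up, and "sst5" is assumed to be one of the registered keys in evoprompt.task.tasks, as suggested by the hard-coded get_task() call):

    python eval_prompt.py --prompt "In which sentiment class does the following text fall?" --task sst5

As committed, evaluate_prompt() always passes the literal "sst5" to get_task(); the --task value and the llama options only reach the task constructor through **kwargs, and the model settings come from the Namespace built under __main__ rather than from additional command-line flags.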