diff --git a/eval_prompt.py b/eval_prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1cab747e890145ad4e8b4179e796f5a5c135b36
--- /dev/null
+++ b/eval_prompt.py
@@ -0,0 +1,53 @@
+from argparse import Namespace
+import argparse
+import logging
+from evoprompt.models import LLMModel
+from evoprompt.task import get_task, tasks
+
+from evoprompt.utils import setup_console_logger
+
+logger = logging.getLogger(__name__)
+
+
+def evaluate_prompt(prompt: str, task: str, **kwargs):
+    logger.info(f'Evaluating prompt "{prompt}"')
+
+    evaluation_model = LLMModel.get_model(name="llamachat", options=Namespace(**kwargs))
+
+    evaluation_task = get_task(
+        task,
+        evaluation_model,
+        **kwargs,
+    )
+
+    eval_score, eval_usage, _ = evaluation_task.evaluate_validation(prompt)
+    logger.info(f"Score on evaluation set: {eval_score}")
+    test_score, test_usage, _ = evaluation_task.evaluate_test(prompt)
+    logger.info(f"Score on test set: {test_score}")
+
+
+if __name__ == "__main__":
+    setup_console_logger()
+
+    argparser = argparse.ArgumentParser()
+    argparser.add_argument("-p", "--prompt", type=str, required=True)
+    argparser.add_argument(
+        "-t", "--task", type=str, choices=sorted(tasks.keys()), required=True
+    )
+    args = argparser.parse_args()
+
+    options = Namespace(
+        llama_path=None,
+        llama_model="QuantFactory/Meta-Llama-3.1-8B-Instruct-GGUF",
+        llama_model_file="Meta-Llama-3.1-8B-Instruct.Q8_0.gguf",
+        chat_format=None,
+        chat_handler=None,
+        verbose=False,
+        llama_verbose=False,
+        disable_cache=False,
+        use_grammar=False,
+        evaluation_strategy="simple",
+        max_tokens=None,
+    )
+
+    evaluate_prompt(**vars(args), **vars(options))
diff --git a/evoprompt/models.py b/evoprompt/models.py
index 8f717a16e214b742b466a126d7c77f13fc59e055..d1ced017baa0c0dca8e9dc06f8fa182d786292da 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -153,6 +153,8 @@ class Llama(LLMModel):
         if seed is not None:
             add_kwargs["seed"] = seed

+        # TODO some options could be optional
+
         if options.llama_path is not None:
             # use local file
             self.model = llama_cpp.Llama(
diff --git a/evoprompt/task/__init__.py b/evoprompt/task/__init__.py
index aa750c77d17bfaddfd53f4651ce92b3906a12e76..5238a2cf8d9a5cf231b4de1d3b18fd161b3af55b 100644
--- a/evoprompt/task/__init__.py
+++ b/evoprompt/task/__init__.py
@@ -22,7 +22,7 @@ tasks = {
 }


-def get_task(name: str, evaluation_model: LLMModel, **options):
+def get_task(name: str, evaluation_model: LLMModel, **options) -> Task:
     if name not in tasks:
         raise ValueError("Model %s does not exist", name)
     return tasks[name](evaluation_model, **options)
@@ -31,7 +31,7 @@
 argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
 argument_group = argument_parser.add_argument_group("Task arguments")
 argument_group.add_argument(
-    "--task", type=str, required=True, choices=sorted(tasks.keys())
+    "--task", type=str, required=True, choices=sorted(tasks.keys())
 )
 argument_group.add_argument("--use-grammar", "-g", action="store_true")
 argument_group.add_argument(
diff --git a/evoprompt/utils.py b/evoprompt/utils.py
index 03c95a4c134d3a8f71ea959ff1e2e4ce35e02f73..4ca878d4b5064d048370abea0bc359289557d6a5 100644
--- a/evoprompt/utils.py
+++ b/evoprompt/utils.py
@@ -7,7 +7,7 @@ from functools import wraps
 from pathlib import Path
 from pprint import pformat
 from textwrap import dedent, indent
-from typing import Any, Callable
+from typing import Any, Callable, Type, TypeVar
 from uuid import uuid4

 import numpy
@@ -177,7 +177,10 @@ class log_calls:
         return arguments


-def get_all_subclasses(cls):
+T = TypeVar("T")
+
+
+def get_all_subclasses(cls: Type[T]) -> set[Type[T]]:
     return set(cls.__subclasses__()).union(
         [s for c in cls.__subclasses__() for s in get_all_subclasses(c)]
     )
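For reference, a minimal sketch of driving the new evaluation entry point programmatically rather than through the CLI. The prompt text and task name below are illustrative only; the option values simply mirror the `Namespace` defaults hard-coded in `eval_prompt.py` above and are forwarded to the model and task constructors.

```python
# Illustrative only: programmatic use of the new evaluation entry point.
# The prompt and task name are made up; the options mirror the defaults
# from eval_prompt.py and are not a recommended configuration per se.
from eval_prompt import evaluate_prompt
from evoprompt.utils import setup_console_logger

setup_console_logger()

options = {
    "llama_path": None,
    "llama_model": "QuantFactory/Meta-Llama-3.1-8B-Instruct-GGUF",
    "llama_model_file": "Meta-Llama-3.1-8B-Instruct.Q8_0.gguf",
    "chat_format": None,
    "chat_handler": None,
    "verbose": False,
    "llama_verbose": False,
    "disable_cache": False,
    "use_grammar": False,
    "evaluation_strategy": "simple",
    "max_tokens": None,
}

# Validation and test scores are written to the log rather than returned.
evaluate_prompt(
    prompt="Classify the sentiment of the given movie review sentence.",
    task="sst5",
    **options,
)
```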