"""Evaluate a single prompt on a task's validation and test splits.

Usage: python <script> -p "<prompt text>" -t <task name> [-v ...]
"""

import argparse
import logging
from argparse import Namespace

import torch

from evoprompt.models import LLMModel
from evoprompt.task import get_task, tasks
from evoprompt.utils import init_rng, setup_console_logger

logger = logging.getLogger(__name__)


def evaluate_prompt(
    prompt: str, task_args: Namespace, model_args: Namespace
) -> tuple:
    """Score *prompt* on a task's validation set and test set.

    Args:
        prompt: The prompt text to evaluate.
        task_args: Options forwarded verbatim to ``get_task``
            (task name, evaluation strategy, ...).
        model_args: Options forwarded verbatim to ``LLMModel.get_model``.

    Returns:
        ``(eval_score, test_score)`` — the validation- and test-set scores.
    """
    logger.info('Evaluating prompt "%s"', prompt)

    evaluation_model = LLMModel.get_model(
        **vars(model_args),
    )
    task = get_task(evaluation_model=evaluation_model, **vars(task_args))

    # Token-usage stats are returned by the task but not needed here;
    # bind them to underscore names to make the discard explicit.
    eval_score, _eval_usage, _ = task.evaluate_validation(prompt)
    logger.info("Score on evaluation set: %s", eval_score)

    test_score, _test_usage, _ = task.evaluate_test(prompt)
    logger.info("Score on test set: %s", test_score)

    return eval_score, test_score


def main() -> None:
    """Parse CLI arguments, build the task/model options, run the evaluation."""
    argparser = argparse.ArgumentParser()
    argparser.add_argument("-p", "--prompt", type=str, required=True)
    argparser.add_argument(
        "-t", "--task", type=str, choices=sorted(tasks.keys()), required=True
    )
    argparser.add_argument(
        "-v", "--verbose", action="count", default=0, help="Increase verbosity"
    )
    args = argparser.parse_args()

    # Fixed seed so repeated evaluations of the same prompt are comparable.
    init_rng(1)
    setup_console_logger(verbosity_level=args.verbose)

    task_options = Namespace(
        name=args.task,
        use_grammar=False,
        evaluation_strategy="simple",
        n_evaluation_demo=1,
    )
    model_options = Namespace(
        # name="hfchat",
        name="alpacahfchat",
        # name="llamachat",
        verbose=args.verbose,
        # model="PKU-Alignment/alpaca-7b-reproduced",
        # model="chavinlo/alpaca-native",
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map="auto",
        # temperature=0.0,
        # torch_dtype is not JSON serializable therefore we ignore it
        ignore_cache_kwargs=["torch_dtype"],
        llama_path=None,
        chat_format=None,
        chat_handler=None,
        llama_verbose=False,
        llama_model="QuantFactory/Meta-Llama-3.1-8B-Instruct-GGUF",
        # llama_model="TheBloke/Llama-2-70B-Chat-GGUF",
        llama_model_file="Meta-Llama-3.1-8B-Instruct.Q8_0.gguf",
        # llama_model_file="llama-2-70b-chat.Q4_K_M.gguf",
        disable_cache=False,
        max_tokens=None,
    )
    print(task_options)
    print(model_options)

    evaluate_prompt(args.prompt, task_options, model_options)


if __name__ == "__main__":
    main()