From 2780fdcba0b660375c7281e5ab1f8f7bcbdce55d Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Tue, 1 Oct 2024 18:20:33 +0200
Subject: [PATCH] Update evaluate prompt script

---
 eval_prompt.py | 49 +++++++++++++++++++++++++++----------------------
 1 file changed, 27 insertions(+), 22 deletions(-)

diff --git a/eval_prompt.py b/eval_prompt.py
index 06a6375..7bab0cf 100644
--- a/eval_prompt.py
+++ b/eval_prompt.py
@@ -11,27 +11,14 @@ from evoprompt.utils import init_rng, setup_console_logger
 logger = logging.getLogger(__name__)
 
 
-def evaluate_prompt(prompt: str, task_name: str, args: Namespace):
+def evaluate_prompt(prompt: str, task_args: Namespace, model_args: Namespace):
     logger.info(f'Evaluating prompt "{prompt}"')
 
     evaluation_model = LLMModel.get_model(
-        # name="hfchat",
-        # name="alpacahfchat",
-        name="llamachat",
-        # model="PKU-Alignment/alpaca-7b-reproduced",
-        # model="chavinlo/alpaca-native",
-        load_in_8bit=True,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        # use less randomness, i.e., more certain outputs
-        do_sample=False,
-        # temperature=0.0,
-        # torch_dtype is not JSON serializable therefore we ignore it
-        ignore_cache_kwargs=["torch_dtype"],
-        **vars(args),
+        **vars(model_args),
     )
 
-    task = get_task(task_name, evaluation_model, **vars(args))
+    task = get_task(evaluation_model=evaluation_model, **vars(task_args))
 
     eval_score, eval_usage, _ = task.evaluate_validation(prompt)
     logger.info(f"Score on evaluation set: {eval_score}")
@@ -55,7 +42,27 @@ if __name__ == "__main__":
     init_rng(1)
     setup_console_logger(verbosity_level=args.verbose)
 
-    options = Namespace(
+    task_options = Namespace(
+        name=args.task,
+        use_grammar=False,
+        evaluation_strategy="simple",
+        n_evaluation_demo=1,
+    )
+    model_options = Namespace(
+        # name="hfchat",
+        name="alpacahfchat",
+        # name="llamachat",
+        verbose=args.verbose,
+        # model="PKU-Alignment/alpaca-7b-reproduced",
+        # model="chavinlo/alpaca-native",
+        load_in_8bit=True,
+        torch_dtype=torch.float16,
+        device_map="auto",
+        # use less randomness, i.e., more certain outputs
+        enforce_randomness=False,
+        # temperature=0.0,
+        # torch_dtype is not JSON serializable therefore we ignore it
+        ignore_cache_kwargs=["torch_dtype"],
         llama_path=None,
         chat_format=None,
         chat_handler=None,
@@ -65,13 +72,11 @@ if __name__ == "__main__":
         llama_model_file="Meta-Llama-3.1-8B-Instruct.Q8_0.gguf",
         # llama_model_file="llama-2-70b-chat.Q4_K_M.gguf",
         disable_cache=False,
-        use_grammar=False,
-        evaluation_strategy="simple",
         max_tokens=None,
-        n_evaluation_demo=1,
-        **vars(args),
     )
+    print(task_options)
+    print(model_options)
 
-    evaluate_prompt(args.prompt, args.task, options)
+    evaluate_prompt(args.prompt, task_options, model_options)
 
     main()
-- 
GitLab