From 2780fdcba0b660375c7281e5ab1f8f7bcbdce55d Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Tue, 1 Oct 2024 18:20:33 +0200
Subject: [PATCH] Update evaluate prompt script

---
 eval_prompt.py | 49 +++++++++++++++++++++++++++----------------------
 1 file changed, 27 insertions(+), 22 deletions(-)

diff --git a/eval_prompt.py b/eval_prompt.py
index 06a6375..7bab0cf 100644
--- a/eval_prompt.py
+++ b/eval_prompt.py
@@ -11,27 +11,14 @@ from evoprompt.utils import init_rng, setup_console_logger
 logger = logging.getLogger(__name__)
 
 
-def evaluate_prompt(prompt: str, task_name: str, args: Namespace):
+def evaluate_prompt(prompt: str, task_args: Namespace, model_args: Namespace):
     logger.info(f'Evaluating prompt "{prompt}"')
 
     evaluation_model = LLMModel.get_model(
-        # name="hfchat",
-        # name="alpacahfchat",
-        name="llamachat",
-        # model="PKU-Alignment/alpaca-7b-reproduced",
-        # model="chavinlo/alpaca-native",
-        load_in_8bit=True,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        # use less randomness, i.e., more certain outputs
-        do_sample=False,
-        # temperature=0.0,
-        # torch_dtype is not JSON serializable therefore we ignore it
-        ignore_cache_kwargs=["torch_dtype"],
-        **vars(args),
+        **vars(model_args),
     )
 
-    task = get_task(task_name, evaluation_model, **vars(args))
+    task = get_task(evaluation_model=evaluation_model, **vars(task_args))
 
     eval_score, eval_usage, _ = task.evaluate_validation(prompt)
     logger.info(f"Score on evaluation set: {eval_score}")
@@ -55,7 +42,27 @@ if __name__ == "__main__":
         init_rng(1)
         setup_console_logger(verbosity_level=args.verbose)
 
-        options = Namespace(
+        task_options = Namespace(
+            name=args.task,
+            use_grammar=False,
+            evaluation_strategy="simple",
+            n_evaluation_demo=1,
+        )
+        model_options = Namespace(
+            # name="hfchat",
+            name="alpacahfchat",
+            # name="llamachat",
+            verbose=args.verbose,
+            # model="PKU-Alignment/alpaca-7b-reproduced",
+            # model="chavinlo/alpaca-native",
+            load_in_8bit=True,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            # use less randomness, i.e., more certain outputs
+            enforce_randomness=False,
+            # temperature=0.0,
+            # torch_dtype is not JSON serializable therefore we ignore it
+            ignore_cache_kwargs=["torch_dtype"],
             llama_path=None,
             chat_format=None,
             chat_handler=None,
@@ -65,13 +72,11 @@ if __name__ == "__main__":
             llama_model_file="Meta-Llama-3.1-8B-Instruct.Q8_0.gguf",
             # llama_model_file="llama-2-70b-chat.Q4_K_M.gguf",
             disable_cache=False,
-            use_grammar=False,
-            evaluation_strategy="simple",
             max_tokens=None,
-            n_evaluation_demo=1,
-            **vars(args),
         )
+        print(task_options)
+        print(model_options)
 
-        evaluate_prompt(args.prompt, args.task, options)
+        evaluate_prompt(args.prompt, task_options, model_options)
 
     main()
-- 
GitLab