Commit 2780fdcb authored by Max Kimmich

Update evaluate prompt script

parent d82b104a
1 merge request: !7 Refactor tasks and models and fix format for various models
@@ -11,27 +11,14 @@ from evoprompt.utils import init_rng, setup_console_logger
 
 logger = logging.getLogger(__name__)
 
 
-def evaluate_prompt(prompt: str, task_name: str, args: Namespace):
+def evaluate_prompt(prompt: str, task_args: Namespace, model_args: Namespace):
     logger.info(f'Evaluating prompt "{prompt}"')
 
     evaluation_model = LLMModel.get_model(
-        # name="hfchat",
-        # name="alpacahfchat",
-        name="llamachat",
-        # model="PKU-Alignment/alpaca-7b-reproduced",
-        # model="chavinlo/alpaca-native",
-        load_in_8bit=True,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        # use less randomness, i.e., more certain outputs
-        do_sample=False,
-        # temperature=0.0,
-        # torch_dtype is not JSON serializable therefore we ignore it
-        ignore_cache_kwargs=["torch_dtype"],
-        **vars(args),
+        **vars(model_args),
     )
 
-    task = get_task(task_name, evaluation_model, **vars(args))
+    task = get_task(evaluation_model=evaluation_model, **vars(task_args))
     eval_score, eval_usage, _ = task.evaluate_validation(prompt)
     logger.info(f"Score on evaluation set: {eval_score}")
@@ -55,7 +42,27 @@ if __name__ == "__main__":
     init_rng(1)
     setup_console_logger(verbosity_level=args.verbose)
 
-    options = Namespace(
+    task_options = Namespace(
+        name=args.task,
+        use_grammar=False,
+        evaluation_strategy="simple",
+        n_evaluation_demo=1,
+    )
+    model_options = Namespace(
+        # name="hfchat",
+        name="alpacahfchat",
+        # name="llamachat",
+        verbose=args.verbose,
+        # model="PKU-Alignment/alpaca-7b-reproduced",
+        # model="chavinlo/alpaca-native",
+        load_in_8bit=True,
+        torch_dtype=torch.float16,
+        device_map="auto",
+        # use less randomness, i.e., more certain outputs
+        enforce_randomness=False,
+        # temperature=0.0,
+        # torch_dtype is not JSON serializable therefore we ignore it
+        ignore_cache_kwargs=["torch_dtype"],
         llama_path=None,
         chat_format=None,
         chat_handler=None,
@@ -65,13 +72,11 @@ if __name__ == "__main__":
         llama_model_file="Meta-Llama-3.1-8B-Instruct.Q8_0.gguf",
         # llama_model_file="llama-2-70b-chat.Q4_K_M.gguf",
         disable_cache=False,
-        use_grammar=False,
-        evaluation_strategy="simple",
         max_tokens=None,
-        n_evaluation_demo=1,
         **vars(args),
     )
-    evaluate_prompt(args.prompt, args.task, options)
+    print(task_options)
+    print(model_options)
+    evaluate_prompt(args.prompt, task_options, model_options)
 
     main()
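
For context on the refactor: evaluate_prompt no longer receives a task name plus one shared args Namespace, but two dedicated Namespaces whose keyword arguments are forwarded verbatim to get_task and LLMModel.get_model respectively. Below is a minimal, trimmed-down sketch of a caller under that split, assuming the evaluate_prompt defined above. The field names come from the diff; the import paths for LLMModel and get_task, the task name "sst2", and the prompt string are illustrative assumptions, and the real script additionally sets llama_path, chat_format, chat_handler, llama_model_file, disable_cache, and max_tokens on model_options.

    from argparse import Namespace

    import torch

    from evoprompt.models import LLMModel  # import path assumed, not shown in the diff
    from evoprompt.task import get_task    # import path assumed, not shown in the diff

    # Task-specific options travel in their own Namespace ...
    task_options = Namespace(
        name="sst2",  # hypothetical task name; the script passes args.task here
        use_grammar=False,
        evaluation_strategy="simple",
        n_evaluation_demo=1,
    )

    # ... while model options stay separate, so evaluate_prompt no longer has to
    # know which keyword arguments belong to the task and which to the model.
    model_options = Namespace(
        name="alpacahfchat",
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map="auto",
        enforce_randomness=False,
        # torch_dtype is not JSON serializable, so it is excluded from the cache key
        ignore_cache_kwargs=["torch_dtype"],
    )

    evaluate_prompt("Classify the sentiment of the given text.", task_options, model_options)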