Skip to content
Snippets Groups Projects
Commit 5d6a9062 authored by Grießhaber Daniel's avatar Grießhaber Daniel :squid:
Browse files

add alpaca prompt evaluation

parent 46433787
No related branches found
No related tags found
No related merge requests found
import argparse
import logging
from argparse import Namespace

from evoprompt.models import LLMModel
from evoprompt.task import get_task, tasks
from evoprompt.utils import setup_console_logger
logger = logging.getLogger(__name__)
def evaluate_prompt(prompt: str, task_name: str, args: Namespace):
    """Score a single instruction prompt on a task's validation and test splits.

    Args:
        prompt: The instruction prompt to evaluate.
        task_name: Name of the task to load via ``get_task``.
        args: Parsed CLI options forwarded to the model factory and the task.

    NOTE(review): the function body appears to continue beyond this view
    (e.g. logging of the test score); only the visible portion is shown here.
    """
    logger.info(f'Evaluating prompt "{prompt}"')
    # Alternative llama.cpp backend, kept for reference:
    # evaluation_model = LLMModel.get_model(name="llamachat", options=args)
    evaluation_model = LLMModel.get_model(
        name="hfllamachat", options=args, model="PKU-Alignment/alpaca-7b-reproduced"
    )
    task = get_task(task_name, evaluation_model, **vars(args))
    eval_score, eval_usage, _ = task.evaluate_validation(prompt)
    logger.info(f"Score on evaluation set: {eval_score}")
    test_score, test_usage, _ = task.evaluate_test(prompt)
if __name__ == "__main__":
options = Namespace(
llama_path=None,
llama_model="QuantFactory/Meta-Llama-3.1-8B-Instruct-GGUF",
llama_model_file="Meta-Llama-3.1-8B-Instruct.Q8_0.gguf",
chat_format=None,
chat_handler=None,
verbose=False,
llama_verbose=False,
llama_model="QuantFactory/Meta-Llama-3.1-8B-Instruct-GGUF",
llama_model_file="Meta-Llama-3.1-8B-Instruct.Q8_0.gguf",
disable_cache=False,
use_grammar=False,
evaluation_strategy="simple",
max_tokens=None,
**vars(args),
)
evaluate_prompt(**vars(args), **vars(options))
evaluate_prompt(args.prompt, args.task, options)
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment