import argparse
import logging
from argparse import Namespace

import torch

from evoprompt.models import LLMModel
from evoprompt.task import get_task, tasks
from evoprompt.utils import init_rng, setup_console_logger
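
# Standalone helper: score a single prompt on a task's validation and test splits.
# Example invocation (valid task names come from the `tasks` registry):
#   python <this script> -t <task> -p "Your prompt here" -vv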

logger = logging.getLogger(__name__)


def evaluate_prompt(prompt: str, task_args: Namespace, model_args: Namespace):
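    """Evaluate a single prompt on the given task's validation and test splits."""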
    logger.info(f'Evaluating prompt "{prompt}"')

    evaluation_model = LLMModel.get_model(
        **vars(model_args),
    )

    task = get_task(evaluation_model=evaluation_model, **vars(task_args))

    eval_score, eval_usage, _ = task.evaluate_validation(prompt)
    logger.info(f"Score on validation set: {eval_score} (usage: {eval_usage})")
    test_score, test_usage, _ = task.evaluate_test(prompt)
    logger.info(f"Score on test set: {test_score} (usage: {test_usage})")


if __name__ == "__main__":

    def main():
        argparser = argparse.ArgumentParser()
        argparser.add_argument("-p", "--prompt", type=str, required=True)
        argparser.add_argument(
            "-t", "--task", type=str, choices=sorted(tasks.keys()), required=True
        )
        argparser.add_argument(
            "-v", "--verbose", action="count", default=0, help="Increase verbosity"
        )
        args = argparser.parse_args()

        init_rng(1)
        setup_console_logger(verbosity_level=args.verbose)

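        # Task configuration, forwarded to get_task() via vars()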
        task_options = Namespace(
            name=args.task,
            use_grammar=False,
            evaluation_strategy="simple",
            n_evaluation_demo=1,
        )
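        # Model configuration, forwarded to LLMModel.get_model() via vars();
        # the llama_* fields presumably only take effect for llama.cpp-based models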
        model_options = Namespace(
            # name="hfchat",
            name="alpacahfchat",
            # name="llamachat",
            verbose=args.verbose,
            # model="PKU-Alignment/alpaca-7b-reproduced",
            # model="chavinlo/alpaca-native",
            load_in_8bit=True,
            torch_dtype=torch.float16,
            device_map="auto",
            # temperature=0.0,
            # torch_dtype is not JSON serializable, so we exclude it via ignore_cache_kwargs
            ignore_cache_kwargs=["torch_dtype"],
            llama_path=None,
            chat_format=None,
            chat_handler=None,
            llama_verbose=False,
            llama_model="QuantFactory/Meta-Llama-3.1-8B-Instruct-GGUF",
            # llama_model="TheBloke/Llama-2-70B-Chat-GGUF",
            llama_model_file="Meta-Llama-3.1-8B-Instruct.Q8_0.gguf",
            # llama_model_file="llama-2-70b-chat.Q4_K_M.gguf",
            disable_cache=False,
            max_tokens=None,
        )
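        # Echo the resolved options with plain print() so they appear regardless of log verbosity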
        print(task_options)
        print(model_options)

        evaluate_prompt(args.prompt, task_options, model_options)

    main()