import inspect
import json
import logging
import re
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from pprint import pformat
from textwrap import dedent, indent
from typing import Any, Callable
from uuid import uuid4

from models import Llama2, OpenAI

# Directory containing this file; run artifacts are created beneath it in runs/.
current_directory = Path(__file__).resolve().parent
logger = logging.getLogger("test-classifier")
logger.setLevel(level=logging.DEBUG)

# Prompt sent to the model to invent a human-readable run name; the reply is
# validated against a single-token regex in initialize_run_directory.
run_name_prompt = """
Create a random name that sounds german or dutch
The parts should be separated by underscores and contain only lowercase.
Only return the name without any text before or after.""".strip()


def initialize_run_directory(model: OpenAI | Llama2):
    """Create a fresh run directory named by *model* and attach a file logger.

    Asks the model for a name, falls back to a random hex string when the
    reply contains no single bare token, and disambiguates with a random
    suffix if the directory already exists.  Returns the created Path.
    """
    response = model(run_name_prompt, chat=True)
    # Accept only a reply line consisting of one identifier-like token.
    run_name_match = re.search(r"^\w+$", response, re.MULTILINE)
    if run_name_match is None:
        run_name = uuid4().hex
    else:
        run_name = run_name_match.group(0)
    run_directory = current_directory / f"runs/{run_name}"
    try:
        run_directory.mkdir(parents=True, exist_ok=False)
    except FileExistsError:
        # The model may repeat a name across runs; keep each run isolated
        # instead of crashing by appending a short random suffix.
        run_directory = current_directory / f"runs/{run_name}_{uuid4().hex[:8]}"
        run_directory.mkdir(parents=True, exist_ok=False)
        run_name = run_directory.name
    file_handler = logging.FileHandler(run_directory / "output.log")
    file_handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    logger.info(f"Hello my name is {run_name} and I live in {run_directory}")
    return run_directory


class log_calls:
    """Decorator that logs a function's arguments before the call and its
    result after, at a configurable level, via the module-level logger.

    Usage: ``@log_calls("description of the step")``.
    """

    # One-line description printed in the prolog of every logged call.
    description: str

    prolog = dedent(
        """
    ----------------------------------------
    {description}
    {func_name}:
        \tArguments:
        {arguments}
        """
    )

    epilog = dedent(
        """\tResult:
        {result}
    ----------------------------------------
    """
    )

    def __init__(self, description: str, level: int = logging.DEBUG):
        self.description = description
        self.level = level

    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            arguments = self._get_named_arguments(func, *args, **kwargs)
            logger.log(
                self.level,
                self.prolog.format(
                    description=self.description,
                    func_name=func.__name__,
                    arguments=indent(pformat(arguments), "\t"),
                ),
            )

            result = func(*args, **kwargs)

            logger.log(
                self.level,
                self.epilog.format(
                    result=indent(pformat(result), "\t"),
                ),
            )
            return result

        return wrapper

    def _get_named_arguments(self, func: Callable[..., Any], *args, **kwargs):
        """Map the call's arguments to parameter names, including defaults.

        Uses Signature.bind so *args/**kwargs and keyword-only parameters are
        attributed correctly (a plain positional-index walk mis-assigns them).
        """
        signature = inspect.signature(func)
        try:
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            return dict(bound.arguments)
        except TypeError:
            # The call does not match the signature (func itself will raise
            # right after logging); fall back to a best-effort index mapping
            # so the prolog can still be emitted.
            arguments = {}
            for argument_index, (argument_name, argument) in enumerate(
                signature.parameters.items()
            ):
                if argument_index < len(args):
                    value = args[argument_index]
                elif argument_name in kwargs:
                    value = kwargs[argument_name]
                else:
                    value = argument.default
                arguments[argument_name] = value
            return arguments


@dataclass(frozen=True)
class Prompt:
    """An immutable prompt with its evaluation score and generation number."""

    content: str
    score: float
    gen: int
    # Unique identifier; deliberately excluded from the hash below.
    id: str = field(default_factory=lambda: uuid4().hex)
    meta: dict = field(default_factory=dict)

    def __str__(self) -> str:
        return self.content

    def __hash__(self) -> int:
        # Hash a tuple of the fields rather than summing their hashes:
        # addition is commutative, so e.g. swapping the score and gen
        # values used to produce colliding hashes.
        return hash(
            (self.content, self.score, self.gen, frozenset(self.meta.items()))
        )


class PromptEncoder(json.JSONEncoder):
    """JSON encoder that serializes Prompt instances as plain field dicts."""

    def default(self, obj):
        # A Prompt becomes its attribute mapping; everything else falls
        # through to the base implementation (which raises TypeError).
        if isinstance(obj, Prompt):
            return vars(obj)
        return super().default(obj)


def save_snapshot(
    run_directory: Path,
    all_prompts: list[Prompt],
    family_tree: dict[str, tuple[str, str] | None],
    P: list[list[str]],
    T: int,
    N: int,
    task,
    model: Llama2 | OpenAI,
    run_options: dict[str, Any],
):
    """Write the optimization state to <run_directory>/snapshot.json.

    Prompts are serialized via PromptEncoder; the task and model are
    recorded by name/configuration only, not pickled.
    """
    # json is already imported at module level; the former function-local
    # re-import was redundant.  Write UTF-8 explicitly so the snapshot does
    # not depend on the platform's default encoding.
    with open(run_directory / "snapshot.json", "w", encoding="utf-8") as f:
        json.dump(
            {
                "all_prompts": all_prompts,
                "family_tree": family_tree,
                "P": P,
                "T": T,
                "N": N,
                "task": {
                    "name": task.__class__.__name__,
                    "validation_dataset": task.validation_dataset.info.dataset_name,
                    "test_dataset": task.test_dataset.info.dataset_name,
                    "metric": task.metric_name,
                    "use_grammar": task.use_grammar,
                },
                "model": {"name": model.__class__.__name__},
                "run_options": run_options,
            },
            f,
            indent=4,
            cls=PromptEncoder,
        )


def load_snapshot(path: Path):
    """Load a snapshot file and return (family_tree, P, S, T, N).

    NOTE(review): save_snapshot in this module never writes an "S" key, so
    snapshot["S"] was a guaranteed KeyError for snapshots produced here.
    It is read with .get() instead, yielding None when absent — confirm
    whether "S" belongs to an older snapshot format.
    """
    with path.open("r", encoding="utf-8") as f:
        snapshot = json.load(f)
    return (
        snapshot["family_tree"],
        snapshot["P"],
        snapshot.get("S"),
        snapshot["T"],
        snapshot["N"],
    )