import inspect
from abc import ABC, abstractmethod
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Any, ClassVar

import openai
from llama_cpp import Llama

from cli import argument_parser
from opt_types import ModelUsage

current_directory = Path(__file__).resolve().parent


class LLMModel(ABC):
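    """Base class for LLM backends.

    Concrete subclasses register themselves in `models` (keyed by their
    lower-cased class name) so they can be selected via `get_model`.
    """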
    models: ClassVar[dict[str, type["LLMModel"]]] = {}
    chat: bool

    def __init_subclass__(cls) -> None:
        # only concrete subclasses can be instantiated by name
        if not inspect.isabstract(cls):
            cls.models[cls.__name__.lower()] = cls
        # register CLI arguments once per class that defines them; otherwise
        # sibling subclasses (e.g., Llama2 and Llama3) would both try to add
        # the same options and argparse would raise a conflict error
        if "register_arguments" in cls.__dict__:
            cls.register_arguments(argument_parser)

    @classmethod
    def get_model(cls, name: str, options: Namespace) -> "LLMModel":
        """Instantiate the model class registered under `name`."""
        if name not in cls.models:
            raise ValueError(f"Model {name} does not exist")
        return cls.models[name](options)

    def __init__(self, options: Namespace):
        self.usage = ModelUsage()
        self.chat = options.chat

    @abstractmethod
    def __call__(
        self,
        system_message: str | None,
        prompt: str,
        *,
        prompt_appendix: str = "",
        prompt_prefix: str = "",
        prompt_suffix: str = "",
        chat: bool | None = None,
        stop: str = "</prompt>",
        max_tokens: int = 200,
        **kwargs: Any,
    ) -> tuple[str, ModelUsage]:
        """Query the model; returns the response text and the usage for this call."""

    @classmethod
    @abstractmethod
    def register_arguments(cls, parser: ArgumentParser):
        pass


class LlamaModel(LLMModel):
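    """Runs a local GGUF model via llama-cpp-python.

    Subclasses only select the chat template by overriding `chat_format`.
    """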

    def __init__(
        self,
        options: Namespace,
        n_gpu_layers: int = 60,
        n_threads: int = 8,
        n_ctx: int = 4096,
        verbose: bool = False,
        **kwargs,
    ) -> None:

        # initialize model
        self.model = Llama(
            options.llama_path,
            chat_format=self.chat_format,
            verbose=verbose,
            n_gpu_layers=n_gpu_layers,
            n_threads=n_threads,
            n_ctx=n_ctx,
            **kwargs,
        )

        super().__init__(options)

    def __call__(
        self,
        system_message: str | None,
        prompt: str,
        *,
        prompt_appendix: str = "",
        prompt_prefix: str = "",
        prompt_suffix: str = "",
        chat: bool | None = None,
        stop: str = "</prompt>",
        max_tokens: int = 200,
        **kwargs: Any,
    ) -> tuple[str, ModelUsage]:
        if chat is None:
            chat = self.chat

        if chat:
            messages = [
                {
                    "role": "user",
                    "content": prompt_prefix + prompt + prompt_suffix + prompt_appendix,
                }
            ]
            if system_message:
                messages.insert(
                    0,
                    {
                        "role": "system",
                        "content": system_message,
                    },
                )
            response = self.model.create_chat_completion(
                messages=messages,
                stop=stop,
                max_tokens=max_tokens,
                **kwargs,
            )
            response_text = response["choices"][0]["message"]["content"]
        else:
            response = self.model.create_completion(
                prompt=(system_message if system_message else "")
                + prompt_prefix
                + prompt
                + prompt_suffix
                + prompt_appendix,
                stop=stop,
                max_tokens=max_tokens,
                **kwargs,
            )
            response_text = response["choices"][0]["text"]
        usage = ModelUsage(**response["usage"])
        self.usage += usage
        return response_text, usage

    @property
    @abstractmethod
    def chat_format(self) -> str:
        """Name of the llama-cpp-python chat template to apply (e.g., "llama-2")."""

    @classmethod
    def register_arguments(cls, parser: ArgumentParser):
        group = parser.add_argument_group(f"{cls.__name__} model arguments")
        group.add_argument(
            "--llama-path", default="models/llama-2-13b-chat.Q5_K_M.gguf"
        )


class Llama2(LlamaModel):
    @property
    def chat_format(self) -> str:
        return "llama-2"


class Llama3(LlamaModel):
    @property
    def chat_format(self) -> str:
        return "llama-3"


class OpenAI(LLMModel):
    """Queries an OpenAI model using its API."""

    def __init__(
        self,
        options: Namespace,
        verbose: bool = False,
        **kwargs,
    ) -> None:
        self.model_name = options.openai_model
        super().__init__(options)

        # initialize client for API calls
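        # (the client reads OPENAI_API_KEY from the environment unless an
        # api_key is passed through kwargs)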
        self.openai_client = openai.OpenAI(**kwargs)

    def __call__(
        self,
        system_message: str | None,
        prompt: str,
        *,
        prompt_appendix: str = "",
        prompt_prefix: str = "",
        prompt_suffix: str = "",
        chat: bool | None = None,
        stop: str = "</prompt>",
        max_tokens: int = 200,
        **kwargs: Any,
    ) -> tuple[str, ModelUsage]:
        if chat is None:
            chat = self.chat

        if chat:
            messages = [
                {
                    "role": "user",
                    "content": prompt_prefix + prompt + prompt_suffix + prompt_appendix,
                }
            ]
            if system_message:
                messages.insert(
                    0,
                    {
                        "role": "system",
                        "content": system_message,
                    },
                )
            response = self.openai_client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                stop=stop,
                max_tokens=max_tokens,
                **kwargs,
            )
            usage = ModelUsage(**response.usage.__dict__)
            self.usage += usage
            return response.choices[0].message.content, usage
        else:
            response = self.openai_client.completions.create(
                model=self.model_name,
                prompt=(system_message if system_message else "")
                + prompt_prefix
                + prompt
                + prompt_suffix
                + prompt_appendix,
                stop=stop,
                max_tokens=max_tokens,
                **kwargs,
            )
            usage = ModelUsage(**response.usage.__dict__)
            self.usage += usage
            return response.choices[0].text, usage

    @classmethod
    def register_arguments(cls, parser: ArgumentParser):
        group = parser.add_argument_group("OpenAI model arguments")
        group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")


argument_group = argument_parser.add_argument_group("Model arguments")
argument_group.add_argument(
    "--evolution-engine",
    "-e",
    type=str,
    choices=LLMModel.models.keys(),
    default="llama2",
)
argument_group.add_argument("--chat", "-c", action="store_true")