import inspect
import json
import logging
import re
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from pprint import pformat
from textwrap import dedent, indent
from typing import Any, Callable
from uuid import uuid4

from models import Llama2, OpenAI

current_directory = Path(__file__).resolve().parent

logger = logging.getLogger("test-classifier")
logger.setLevel(level=logging.DEBUG)

run_name_prompt = """
Create a random name that sounds German or Dutch.
The parts should be separated by underscores and contain only lowercase letters.
Only return the name, without any text before or after.""".strip()


def initialize_run_directory(model: OpenAI | Llama2):
    response = model(run_name_prompt, chat=True)
    run_name_match = re.search(r"^\w+$", response, re.MULTILINE)
    # Fall back to a random hex name if the model response is unusable.
    if run_name_match is None:
        run_name = uuid4().hex
    else:
        run_name = run_name_match.group(0)
    run_directory = current_directory / f"runs/{run_name}"
    run_directory.mkdir(parents=True, exist_ok=False)

    file_handler = logging.FileHandler(run_directory / "output.log")
    file_handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    logger.info(f"Hello my name is {run_name} and I live in {run_directory}")
    return run_directory


class log_calls:
    """Decorator that logs a function's arguments and its result."""

    description: str

    prolog = dedent(
        """
        ----------------------------------------
        {description}
        {func_name}:
        \tArguments:
        {arguments}
        """
    )
    epilog = dedent(
        """\tResult:
        {result}
        ----------------------------------------
        """
    )

    def __init__(self, description: str, level: int = logging.DEBUG):
        self.description = description
        self.level = level

    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            arguments = self._get_named_arguments(func, *args, **kwargs)
            logger.log(
                self.level,
                self.prolog.format(
                    description=self.description,
                    func_name=func.__name__,
                    arguments=indent(pformat(arguments), "\t"),
                ),
            )
            result = func(*args, **kwargs)
            logger.log(
                self.level,
                self.epilog.format(
                    result=indent(pformat(result), "\t"),
                ),
            )
            return result

        return wrapper

    def _get_named_arguments(self, func: Callable[..., Any], *args, **kwargs):
        signature = inspect.signature(func)
        arguments = {}
        for argument_index, (argument_name, argument) in enumerate(
            signature.parameters.items()
        ):
            if argument_index < len(args):
                # argument_name is from args
                value = args[argument_index]
            elif argument_name in kwargs:
                # argument_name is from kwargs
                value = kwargs[argument_name]
            else:
                # argument_name is from defaults
                value = argument.default
            arguments[argument_name] = value
        return arguments


@dataclass(frozen=True)
class Prompt:
    content: str
    score: float
    gen: int
    id: str = field(default_factory=lambda: uuid4().hex)
    meta: dict = field(default_factory=dict)

    def __str__(self) -> str:
        return self.content

    def __hash__(self) -> int:
        # Hash a tuple of the identifying fields rather than summing
        # individual hashes; meta is folded in as a frozenset of its items.
        return hash(
            (self.content, self.score, self.gen, frozenset(self.meta.items()))
        )


class PromptEncoder(json.JSONEncoder):
    """JSON encoder that serializes Prompt instances as plain dicts."""

    def default(self, obj):
        if isinstance(obj, Prompt):
            return obj.__dict__
        return json.JSONEncoder.default(self, obj)


def save_snapshot(
    run_directory: Path,
    all_prompts: list[Prompt],
    family_tree: dict[str, tuple[str, str] | None],
    P: list[list[str]],
    T: int,
    N: int,
    task,
    model: Llama2 | OpenAI,
    run_options: dict[str, Any],
):
    with open(run_directory / "snapshot.json", "w") as f:
        json.dump(
            {
                "all_prompts": all_prompts,
                "family_tree": family_tree,
                "P": P,
                "T": T,
                "N": N,
                "task": {
                    "name": task.__class__.__name__,
                    "validation_dataset": task.validation_dataset.info.dataset_name,
                    "test_dataset": task.test_dataset.info.dataset_name,
                    "metric": task.metric_name,
                    "use_grammar": task.use_grammar,
                },
                "model": {"name": model.__class__.__name__},
                "run_options": run_options,
            },
            f,
            indent=4,
            cls=PromptEncoder,
        )


def load_snapshot(path: Path):
    with path.open("r") as f:
        snapshot = json.load(f)
    # save_snapshot above never writes an "S" key, so fetch it defensively
    # to avoid a KeyError when loading snapshots it produced.
    return (
        snapshot["family_tree"],
        snapshot["P"],
        snapshot.get("S"),
        snapshot["T"],
        snapshot["N"],
    )
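

# Illustrative sketch, not part of the original module: the manual
# signature walk in log_calls._get_named_arguments can also be expressed
# with inspect.Signature.bind, which maps positional, keyword, and
# default values onto parameter names in one step. The helper name below
# is hypothetical and nothing in the module calls it.
def _demo_named_arguments(func: Callable[..., Any], *args, **kwargs) -> dict[str, Any]:
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()  # fill in defaults for parameters not supplied
    return dict(bound.arguments)
# For a function f(a, b=1): _demo_named_arguments(f, 2) == {"a": 2, "b": 1}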