diff --git a/main.py b/main.py
index 32633ae54813e95ed6abcb3aa4eafdacad37982e..60dd47c45202e3f999b0424f500d4ea8e60c2902 100644
--- a/main.py
+++ b/main.py
@@ -10,7 +10,7 @@ from tqdm import trange
 from cli import argument_parser
 from models import Llama2, OpenAI
 from task import QuestionAnswering, SentimentAnalysis
-from utils import Prompt, log_calls, logger, save_snapshot
+from utils import Prompt, initialize_run_directory, log_calls, logger, save_snapshot
 
 
 def conv2bool(_str: Any):
@@ -250,6 +250,7 @@ def run_episode(evo_alg_str: str, debug: bool = False):
         P.append([prompt.id for prompt in population])
 
     save_snapshot(
+        run_directory,
         all_prompts,
         family_tree,
         P,
@@ -311,6 +312,8 @@ if __name__ == "__main__":
         chat=options.chat,
     )
 
+    run_directory = initialize_run_directory(evolution_model)
+
     match options.task:
         case "sa":
             logger.info("Running with task sentiment analysis on dataset SetFit/sst2")
diff --git a/utils.py b/utils.py
index 6dcd295524f8b11f1a377c3e69a13a8c01a4f091..cbbd7012e2ffef5226342a4c0121a5390d2b4994 100644
--- a/utils.py
+++ b/utils.py
@@ -2,7 +2,6 @@ import inspect
 import json
 import logging
 from dataclasses import dataclass, field
-from datetime import datetime
 from functools import wraps
 from pathlib import Path
 from pprint import pformat
@@ -13,16 +12,25 @@ from uuid import uuid4
 from models import Llama2, OpenAI
 
 current_directory = Path(__file__).resolve().parent
-run_directory = (
-    current_directory / f"runs/run-{datetime.now().strftime('%Y-%m-%d %H-%M-%S')}"
-)
-run_directory.mkdir(parents=True, exist_ok=False)
 logger = logging.getLogger("test-classifier")
 logger.setLevel(level=logging.DEBUG)
-file_handler = logging.FileHandler(run_directory / "output.log")
-file_handler.setLevel(logging.DEBUG)
-formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
-logger.addHandler(file_handler)
+
+run_name_prompt = """Create a random experiment name consisting of only a first and last name. The name should sound German or Dutch. The parts should be separated by underscores and contain only lowercase. </prompt>.
+    <prompt>"""
+
+
+def initialize_run_directory(model: OpenAI | Llama2):
+    run_name = model(run_name_prompt)
+    run_directory = current_directory / f"runs/run-{run_name}"
+    run_directory.mkdir(parents=True, exist_ok=False)
+    file_handler = logging.FileHandler(run_directory / "output.log")
+    file_handler.setLevel(logging.DEBUG)
+    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
+    file_handler.setFormatter(formatter)
+    logger.addHandler(file_handler)
+
+    logger.info(f"initialized run directory at {run_directory}")
+    return run_directory
 
 
 class log_calls:
@@ -122,6 +130,7 @@ class PromptEncoder(json.JSONEncoder):
 
 
 def save_snapshot(
+    run_directory: Path,
     all_prompts: list[Prompt],
     family_tree: dict[str, tuple[str, str] | None],
     P: list[list[str]],