From 5218f36ab0485b788de80cefdff1f961654c24c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de> Date: Wed, 13 Mar 2024 19:48:31 +0100 Subject: [PATCH] =?UTF-8?q?let=20the=20llm=20create=20the=20run=20name=20?= =?UTF-8?q?=F0=9F=A4=B7=E2=80=8D=E2=99=82=EF=B8=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 5 ++++- utils.py | 27 ++++++++++++++++++--------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/main.py b/main.py index 32633ae..60dd47c 100644 --- a/main.py +++ b/main.py @@ -10,7 +10,7 @@ from tqdm import trange from cli import argument_parser from models import Llama2, OpenAI from task import QuestionAnswering, SentimentAnalysis -from utils import Prompt, log_calls, logger, save_snapshot +from utils import Prompt, initialize_run_directory, log_calls, logger, save_snapshot def conv2bool(_str: Any): @@ -250,6 +250,7 @@ def run_episode(evo_alg_str: str, debug: bool = False): P.append([prompt.id for prompt in population]) save_snapshot( + run_directory, all_prompts, family_tree, P, @@ -311,6 +312,8 @@ if __name__ == "__main__": chat=options.chat, ) + run_directory = initialize_run_directory(evolution_model) + match options.task: case "sa": logger.info("Running with task sentiment analysis on dataset SetFit/sst2") diff --git a/utils.py b/utils.py index 6dcd295..cbbd701 100644 --- a/utils.py +++ b/utils.py @@ -2,7 +2,6 @@ import inspect import json import logging from dataclasses import dataclass, field -from datetime import datetime from functools import wraps from pathlib import Path from pprint import pformat @@ -13,16 +12,25 @@ from uuid import uuid4 from models import Llama2, OpenAI current_directory = Path(__file__).resolve().parent -run_directory = ( - current_directory / f"runs/run-{datetime.now().strftime('%Y-%m-%d %H-%M-%S')}" -) -run_directory.mkdir(parents=True, exist_ok=False) logger = logging.getLogger("test-classifier") logger.setLevel(level=logging.DEBUG) -file_handler = logging.FileHandler(run_directory / "output.log") -file_handler.setLevel(logging.DEBUG) -formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") -logger.addHandler(file_handler) + +run_name_prompt = """Create a random experiemnt name consisting of only a first and last name. The name should sound german or dutch. The parts should be separated by underscores and contain only lowercase. </prompt>. + <prompt>""" + + +def initialize_run_directory(model: OpenAI | Llama2): + run_name = model(run_name_prompt) + run_directory = current_directory / f"runs/run-{run_name}" + run_directory.mkdir(parents=True, exist_ok=False) + file_handler = logging.FileHandler(run_directory / "output.log") + file_handler.setLevel(logging.DEBUG) + formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + logger.info(f"initialized run directory at {run_directory}") + return run_directory class log_calls: @@ -122,6 +130,7 @@ class PromptEncoder(json.JSONEncoder): def save_snapshot( + run_directory: Path, all_prompts: list[Prompt], family_tree: dict[str, tuple[str, str] | None], P: list[list[str]], -- GitLab