From 5218f36ab0485b788de80cefdff1f961654c24c2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Wed, 13 Mar 2024 19:48:31 +0100
Subject: [PATCH] =?UTF-8?q?let=20the=20llm=20create=20the=20run=20name=20?=
 =?UTF-8?q?=F0=9F=A4=B7=E2=80=8D=E2=99=82=EF=B8=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 main.py  |  5 ++++-
 utils.py | 27 ++++++++++++++++++---------
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/main.py b/main.py
index 32633ae..60dd47c 100644
--- a/main.py
+++ b/main.py
@@ -10,7 +10,7 @@ from tqdm import trange
 from cli import argument_parser
 from models import Llama2, OpenAI
 from task import QuestionAnswering, SentimentAnalysis
-from utils import Prompt, log_calls, logger, save_snapshot
+from utils import Prompt, initialize_run_directory, log_calls, logger, save_snapshot
 
 
 def conv2bool(_str: Any):
@@ -250,6 +250,7 @@ def run_episode(evo_alg_str: str, debug: bool = False):
         P.append([prompt.id for prompt in population])
 
     save_snapshot(
+        run_directory,
         all_prompts,
         family_tree,
         P,
@@ -311,6 +312,8 @@ if __name__ == "__main__":
                 chat=options.chat,
             )
 
+    run_directory = initialize_run_directory(evolution_model)
+
     match options.task:
         case "sa":
             logger.info("Running with task sentiment analysis on dataset SetFit/sst2")
diff --git a/utils.py b/utils.py
index 6dcd295..cbbd701 100644
--- a/utils.py
+++ b/utils.py
@@ -2,7 +2,6 @@ import inspect
 import json
 import logging
 from dataclasses import dataclass, field
-from datetime import datetime
 from functools import wraps
 from pathlib import Path
 from pprint import pformat
@@ -13,16 +12,25 @@ from uuid import uuid4
 from models import Llama2, OpenAI
 
 current_directory = Path(__file__).resolve().parent
-run_directory = (
-    current_directory / f"runs/run-{datetime.now().strftime('%Y-%m-%d %H-%M-%S')}"
-)
-run_directory.mkdir(parents=True, exist_ok=False)
 logger = logging.getLogger("test-classifier")
 logger.setLevel(level=logging.DEBUG)
-file_handler = logging.FileHandler(run_directory / "output.log")
-file_handler.setLevel(logging.DEBUG)
-formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
-logger.addHandler(file_handler)
+
+run_name_prompt = """Create a random experiemnt name consisting of only a first and last name. The name should sound german or dutch. The parts should be separated by underscores and contain only lowercase. </prompt>.
+                    <prompt>"""
+
+
+def initialize_run_directory(model: OpenAI | Llama2):
+    run_name = model(run_name_prompt)
+    run_directory = current_directory / f"runs/run-{run_name}"
+    run_directory.mkdir(parents=True, exist_ok=False)
+    file_handler = logging.FileHandler(run_directory / "output.log")
+    file_handler.setLevel(logging.DEBUG)
+    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
+    file_handler.setFormatter(formatter)
+    logger.addHandler(file_handler)
+
+    logger.info(f"initialized run directory at {run_directory}")
+    return run_directory
 
 
 class log_calls:
@@ -122,6 +130,7 @@ class PromptEncoder(json.JSONEncoder):
 
 
 def save_snapshot(
+    run_directory: Path,
     all_prompts: list[Prompt],
     family_tree: dict[str, tuple[str, str] | None],
     P: list[list[str]],
-- 
GitLab