Skip to content
Snippets Groups Projects
Commit 5218f36a authored by Grießhaber Daniel's avatar Grießhaber Daniel :squid:
Browse files

let the llm create the run name :man_shrugging:

parent 2f3cbc02
No related branches found
No related tags found
No related merge requests found
......@@ -10,7 +10,7 @@ from tqdm import trange
from cli import argument_parser
from models import Llama2, OpenAI
from task import QuestionAnswering, SentimentAnalysis
from utils import Prompt, log_calls, logger, save_snapshot
from utils import Prompt, initialize_run_directory, log_calls, logger, save_snapshot
def conv2bool(_str: Any):
......@@ -250,6 +250,7 @@ def run_episode(evo_alg_str: str, debug: bool = False):
P.append([prompt.id for prompt in population])
save_snapshot(
run_directory,
all_prompts,
family_tree,
P,
......@@ -311,6 +312,8 @@ if __name__ == "__main__":
chat=options.chat,
)
run_directory = initialize_run_directory(evolution_model)
match options.task:
case "sa":
logger.info("Running with task sentiment analysis on dataset SetFit/sst2")
......
......@@ -2,7 +2,6 @@ import inspect
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime
from functools import wraps
from pathlib import Path
from pprint import pformat
......@@ -13,16 +12,25 @@ from uuid import uuid4
from models import Llama2, OpenAI
current_directory = Path(__file__).resolve().parent
run_directory = (
current_directory / f"runs/run-{datetime.now().strftime('%Y-%m-%d %H-%M-%S')}"
)
run_directory.mkdir(parents=True, exist_ok=False)
logger = logging.getLogger("test-classifier")
logger.setLevel(level=logging.DEBUG)
file_handler = logging.FileHandler(run_directory / "output.log")
file_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
logger.addHandler(file_handler)
run_name_prompt = """Create a random experiemnt name consisting of only a first and last name. The name should sound german or dutch. The parts should be separated by underscores and contain only lowercase. </prompt>.
<prompt>"""
def initialize_run_directory(model: OpenAI | Llama2):
run_name = model(run_name_prompt)
run_directory = current_directory / f"runs/run-{run_name}"
run_directory.mkdir(parents=True, exist_ok=False)
file_handler = logging.FileHandler(run_directory / "output.log")
file_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.info(f"initialized run directory at {run_directory}")
return run_directory
class log_calls:
......@@ -122,6 +130,7 @@ class PromptEncoder(json.JSONEncoder):
def save_snapshot(
run_directory: Path,
all_prompts: list[Prompt],
family_tree: dict[str, tuple[str, str] | None],
P: list[list[str]],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment