Skip to content
Snippets Groups Projects
Commit 5218f36a authored by Grießhaber Daniel's avatar Grießhaber Daniel :squid:
Browse files

let the llm create the run name :man_shrugging:

parent 2f3cbc02
No related branches found
No related tags found
No related merge requests found
...@@ -10,7 +10,7 @@ from tqdm import trange ...@@ -10,7 +10,7 @@ from tqdm import trange
from cli import argument_parser from cli import argument_parser
from models import Llama2, OpenAI from models import Llama2, OpenAI
from task import QuestionAnswering, SentimentAnalysis from task import QuestionAnswering, SentimentAnalysis
from utils import Prompt, log_calls, logger, save_snapshot from utils import Prompt, initialize_run_directory, log_calls, logger, save_snapshot
def conv2bool(_str: Any): def conv2bool(_str: Any):
...@@ -250,6 +250,7 @@ def run_episode(evo_alg_str: str, debug: bool = False): ...@@ -250,6 +250,7 @@ def run_episode(evo_alg_str: str, debug: bool = False):
P.append([prompt.id for prompt in population]) P.append([prompt.id for prompt in population])
save_snapshot( save_snapshot(
run_directory,
all_prompts, all_prompts,
family_tree, family_tree,
P, P,
...@@ -311,6 +312,8 @@ if __name__ == "__main__": ...@@ -311,6 +312,8 @@ if __name__ == "__main__":
chat=options.chat, chat=options.chat,
) )
run_directory = initialize_run_directory(evolution_model)
match options.task: match options.task:
case "sa": case "sa":
logger.info("Running with task sentiment analysis on dataset SetFit/sst2") logger.info("Running with task sentiment analysis on dataset SetFit/sst2")
......
...@@ -2,7 +2,6 @@ import inspect ...@@ -2,7 +2,6 @@ import inspect
import json import json
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from pprint import pformat from pprint import pformat
...@@ -13,16 +12,25 @@ from uuid import uuid4 ...@@ -13,16 +12,25 @@ from uuid import uuid4
from models import Llama2, OpenAI from models import Llama2, OpenAI
current_directory = Path(__file__).resolve().parent current_directory = Path(__file__).resolve().parent
run_directory = (
current_directory / f"runs/run-{datetime.now().strftime('%Y-%m-%d %H-%M-%S')}"
)
run_directory.mkdir(parents=True, exist_ok=False)
logger = logging.getLogger("test-classifier") logger = logging.getLogger("test-classifier")
logger.setLevel(level=logging.DEBUG) logger.setLevel(level=logging.DEBUG)
file_handler = logging.FileHandler(run_directory / "output.log")
file_handler.setLevel(logging.DEBUG) run_name_prompt = """Create a random experiemnt name consisting of only a first and last name. The name should sound german or dutch. The parts should be separated by underscores and contain only lowercase. </prompt>.
formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") <prompt>"""
logger.addHandler(file_handler)
def initialize_run_directory(model: OpenAI | Llama2):
run_name = model(run_name_prompt)
run_directory = current_directory / f"runs/run-{run_name}"
run_directory.mkdir(parents=True, exist_ok=False)
file_handler = logging.FileHandler(run_directory / "output.log")
file_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.info(f"initialized run directory at {run_directory}")
return run_directory
class log_calls: class log_calls:
...@@ -122,6 +130,7 @@ class PromptEncoder(json.JSONEncoder): ...@@ -122,6 +130,7 @@ class PromptEncoder(json.JSONEncoder):
def save_snapshot( def save_snapshot(
run_directory: Path,
all_prompts: list[Prompt], all_prompts: list[Prompt],
family_tree: dict[str, tuple[str, str] | None], family_tree: dict[str, tuple[str, str] | None],
P: list[list[str]], P: list[list[str]],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment