From eaa1f0d57ead79e1d3226969769a08a9ca5170e7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Wed, 13 Mar 2024 07:30:02 +0100
Subject: [PATCH] add extra metadata to snapshot

---
 main.py   |  5 ++---
 models.py |  9 ++-------
 utils.py  | 19 +++++++++++++++++--
 3 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/main.py b/main.py
index fe8904c..063f45b 100644
--- a/main.py
+++ b/main.py
@@ -135,7 +135,7 @@ def update(prompts: list[str], N: int):
     # Specifically, at each iteration, EVOPROMPT based on GA produces N new prompts,
     # which are combined with the current population of N prompts.
     # The updated population is then selected by retaining the N prompts with the highest scores.
-    retained_prompts = []
+    retained_prompts: list[Prompt] = []
     min_retained_score = 0
     for prompt in prompts:
         if len(retained_prompts) < N:
@@ -202,7 +202,6 @@ def run_episode(evo_alg_str: str, debug: bool = False):
         )
 
         new_evolutions = []
-        new_evolutions_scores = []
 
         for i in trange(N, desc="N", leave=False):
             # for both GA and DE we start with two parent prompts
@@ -253,7 +252,7 @@ def run_episode(evo_alg_str: str, debug: bool = False):
         # store new generation
         P.append([prompt.id for prompt in population])
 
-        save_snapshot(all_prompts, family_tree, P, T, N)
+        save_snapshot(all_prompts, family_tree, P, T, N, task, evolution_model)
     # Line 8: Return the best prompt, p∗, among the final population PT :
     # p∗ ← argmaxp∈PT f(p, D)
     p = all_prompts[max(P[-1], key=lambda prompt_id: all_prompts[prompt_id].score)]
diff --git a/models.py b/models.py
index a640c53..80e5868 100644
--- a/models.py
+++ b/models.py
@@ -1,12 +1,8 @@
-from abc import abstractmethod
 from pathlib import Path
 from typing import Any
 
-from llama_cpp import Llama
 import openai
-
-from utils import log_calls
-
+from llama_cpp import Llama
 
 current_directory = Path(__file__).resolve().parent
 
@@ -39,11 +35,10 @@ class Llama2:
             **kwargs,
         )
 
-    # @log_calls("Running Llama model")
     def __call__(
         self,
         prompt: str,
-        chat: bool = None,
+        chat: bool | None = None,
         stop: str = "</prompt>",
         max_tokens: int = 200,
         **kwargs: Any
diff --git a/utils.py b/utils.py
index 8ea4d66..1039e49 100644
--- a/utils.py
+++ b/utils.py
@@ -10,7 +10,7 @@ from textwrap import dedent, indent
 from typing import Any, Callable
 from uuid import uuid4
 
-from graphviz import Digraph
+from models import Llama2, OpenAI
 
 current_directory = Path(__file__).resolve().parent
 run_directory = (
@@ -121,7 +121,15 @@ class PromptEncoder(json.JSONEncoder):
         return json.JSONEncoder.default(self, obj)
 
 
-def save_snapshot(all_prompts, family_tree, P, T, N):
+def save_snapshot(
+    all_prompts: list[Prompt],
+    family_tree: dict[str, tuple[str, str] | None],
+    P: list[list[str]],
+    T: int,
+    N: int,
+    task,
+    model: Llama2 | OpenAI,
+):
     import json
 
     with open(run_directory / "snapshot.json", "w") as f:
@@ -132,6 +140,13 @@
                 "P": P,
                 "T": T,
                 "N": N,
+                "task": {
+                    "name": task.__class__.__name__,
+                    "validation_dataset": task.validation_dataset.info.dataset_name,
+                    "test_dataset": task.test_dataset.info.dataset_name,
+                    "metric": task.metric_name,
+                },
+                "model": {"name": model.__class__.__name__},
             },
             f,
             indent=4,
--
GitLab
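
For illustration, a minimal, self-contained sketch of the metadata block that the patched save_snapshot() now adds to snapshot.json. The task class, dataset name, and metric values below are hypothetical stand-ins chosen for the example; only the key layout mirrors the patch.

    # Illustrative sketch only: reproduces the shape of the new "task"/"model"
    # entries written by the patched save_snapshot(); names are hypothetical.
    import json


    class SentimentAnalysis:  # hypothetical task, stands in for the real task object
        metric_name = "accuracy"


    class Llama2:  # stand-in for the evolution model wrapper from models.py
        pass


    task, model = SentimentAnalysis(), Llama2()
    metadata = {
        "task": {
            "name": task.__class__.__name__,  # -> "SentimentAnalysis"
            "validation_dataset": "sst2",  # would be task.validation_dataset.info.dataset_name
            "test_dataset": "sst2",  # would be task.test_dataset.info.dataset_name
            "metric": task.metric_name,  # -> "accuracy"
        },
        "model": {"name": model.__class__.__name__},  # -> "Llama2"
    }
    print(json.dumps(metadata, indent=4))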