Commit eaa1f0d5 authored by Grießhaber Daniel

add extra metadata to snapshot

parent 4f666b9f
@@ -135,7 +135,7 @@ def update(prompts: list[str], N: int):
     # Specifically, at each iteration, EVOPROMPT based on GA produces N new prompts,
     # which are combined with the current population of N prompts.
     # The updated population is then selected by retaining the N prompts with the highest scores.
-    retained_prompts = []
+    retained_prompts: list[Prompt] = []
     min_retained_score = 0
     for prompt in prompts:
         if len(retained_prompts) < N:
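The comment in this hunk describes EvoPrompt's GA update as a top-N cut over the union of the old population and the newly evolved prompts. A minimal standalone sketch of that selection, assuming a Prompt type that carries a numeric score (the names here are illustrative, not the repository's exact API):

from dataclasses import dataclass, field
from uuid import uuid4


@dataclass
class Prompt:
    content: str
    score: float
    id: str = field(default_factory=lambda: str(uuid4()))


def select_top_n(population: list[Prompt], new_prompts: list[Prompt], N: int) -> list[Prompt]:
    # Merge the current population with the N freshly evolved prompts,
    # then retain only the N highest-scoring ones.
    combined = population + new_prompts
    return sorted(combined, key=lambda p: p.score, reverse=True)[:N]

The streaming variant in the diff, which tracks min_retained_score while appending, avoids a full re-sort but should yield the same retained set.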
@@ -202,7 +202,6 @@ def run_episode(evo_alg_str: str, debug: bool = False):
     )
     new_evolutions = []
-    new_evolutions_scores = []
     for i in trange(N, desc="N", leave=False):
         # for both GA and DE we start with two parent prompts
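The loop body that actually picks the two parents is truncated in this hunk. As a hedged sketch of one common choice for the GA variant, score-proportional (roulette-wheel) sampling, reusing the Prompt type from the sketch above (this may differ from what the repository actually does):

import random


def sample_parents(population: list[Prompt]) -> tuple[Prompt, Prompt]:
    # Roulette-wheel selection: the chance of a prompt being chosen as a
    # parent is proportional to its (non-negative) score.
    weights = [p.score for p in population]
    parent_1, parent_2 = random.choices(population, weights=weights, k=2)
    return parent_1, parent_2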
@@ -253,7 +252,7 @@ def run_episode(evo_alg_str: str, debug: bool = False):
     # store new generation
     P.append([prompt.id for prompt in population])
-    save_snapshot(all_prompts, family_tree, P, T, N)
+    save_snapshot(all_prompts, family_tree, P, T, N, task, evolution_model)
     # Line 8: Return the best prompt, p∗, among the final population P_T:
     # p∗ ← argmax_{p∈P_T} f(p, D)
     p = all_prompts[max(P[-1], key=lambda prompt_id: all_prompts[prompt_id].score)]
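The argmax in the comment maps directly onto Python's built-in max with a key function; spelled out as a helper (illustrative, assuming all_prompts maps prompt ids to scored Prompt objects):

def best_prompt(final_generation: list[str], all_prompts: dict[str, Prompt]) -> Prompt:
    # p* <- argmax_{p in P_T} f(p, D): resolve the highest-scoring id
    # of the last generation P[-1] to its Prompt object.
    best_id = max(final_generation, key=lambda prompt_id: all_prompts[prompt_id].score)
    return all_prompts[best_id]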
 from abc import abstractmethod
 from pathlib import Path
 from typing import Any
 from llama_cpp import Llama
 import openai
 from utils import log_calls
-from llama_cpp import Llama
 current_directory = Path(__file__).resolve().parent
@@ -39,11 +35,10 @@ class Llama2:
             **kwargs,
         )

     # @log_calls("Running Llama model")
     def __call__(
         self,
         prompt: str,
-        chat: bool = None,
+        chat: bool | None = None,
         stop: str = "</prompt>",
         max_tokens: int = 200,
         **kwargs: Any
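The annotation change is a typing fix rather than a behavioural one: with chat: bool = None, the default value None is not a member of the annotated type, which strict checkers such as mypy flag; bool | None admits all three intended states. A minimal illustration outside the class (hypothetical standalone function with an assumed fallback, not the method itself):

def generate(prompt: str, chat: bool | None = None) -> str:
    # None means "the caller did not choose"; an explicit True/False
    # overrides whatever default the surrounding object would apply.
    if chat is None:
        chat = False  # assumed fallback, for this sketch only
    mode = "chat" if chat else "completion"
    return f"[{mode}] {prompt}"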
@@ -10,7 +10,7 @@ from textwrap import dedent, indent
 from typing import Any, Callable
 from uuid import uuid4
 from graphviz import Digraph
-from models import Llama2
+from models import Llama2, OpenAI
 current_directory = Path(__file__).resolve().parent
 run_directory = (
@@ -121,7 +121,15 @@ class PromptEncoder(json.JSONEncoder):
         return json.JSONEncoder.default(self, obj)

-def save_snapshot(all_prompts, family_tree, P, T, N):
+def save_snapshot(
+    all_prompts: list[Prompt],
+    family_tree: dict[str, tuple[str, str] | None],
+    P: list[list[str]],
+    T: int,
+    N: int,
+    task,
+    model: Llama2 | OpenAI,
+):
     import json

     with open(run_directory / "snapshot.json", "w") as f:
@@ -132,6 +140,13 @@ def save_snapshot(all_prompts, family_tree, P, T, N):
                 "P": P,
                 "T": T,
                 "N": N,
+                "task": {
+                    "name": task.__class__.__name__,
+                    "validation_dataset": task.validation_dataset.info.dataset_name,
+                    "test_dataset": task.test_dataset.info.dataset_name,
+                    "metric": task.metric_name,
+                },
+                "model": {"name": model.__class__.__name__},
             },
             f,
             indent=4,
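With these additions, the metadata block serialised into snapshot.json has roughly the following shape, written here as a Python dict (values are illustrative; the dataset names are assumed to come from the dataset info objects of the task's validation and test splits):

snapshot_metadata = {
    # ...prompt and family-tree entries omitted, only keys visible in this diff...
    "P": [["<prompt-id>", "..."]],     # one list of prompt ids per generation
    "T": 10,                           # number of iterations (illustrative)
    "N": 8,                            # population size (illustrative)
    "task": {
        "name": "SentimentAnalysis",   # task.__class__.__name__ (illustrative)
        "validation_dataset": "sst2",  # from the validation split's dataset info
        "test_dataset": "sst2",
        "metric": "accuracy",          # task.metric_name
    },
    "model": {"name": "Llama2"},       # model.__class__.__name__
}

Recording the task and model alongside the population makes each snapshot self-describing, so a run can be identified later without consulting the code or command line that produced it.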