Skip to content
Snippets Groups Projects
Commit 6a634ce2 authored by Grießhaber Daniel's avatar Grießhaber Daniel :squid:
Browse files

add run options to snapshot

parent b323f2e9
No related branches found
No related tags found
No related merge requests found
...@@ -249,7 +249,16 @@ def run_episode(evo_alg_str: str, debug: bool = False): ...@@ -249,7 +249,16 @@ def run_episode(evo_alg_str: str, debug: bool = False):
# store new generation # store new generation
P.append([prompt.id for prompt in population]) P.append([prompt.id for prompt in population])
save_snapshot(all_prompts, family_tree, P, T, N, task, evolution_model) save_snapshot(
all_prompts,
family_tree,
P,
T,
N,
task,
evolution_model,
options.__dict__,
)
# Line 8: Return the best prompt, p∗, among the final population PT : # Line 8: Return the best prompt, p∗, among the final population PT :
# p∗ ← argmaxp∈PT f(p, D) # p∗ ← argmaxp∈PT f(p, D)
p = all_prompts[max(P[-1], key=lambda prompt_id: all_prompts[prompt_id].score)] p = all_prompts[max(P[-1], key=lambda prompt_id: all_prompts[prompt_id].score)]
......
...@@ -129,6 +129,7 @@ def save_snapshot( ...@@ -129,6 +129,7 @@ def save_snapshot(
N: int, N: int,
task, task,
model: Llama2 | OpenAI, model: Llama2 | OpenAI,
run_options: dict[str, Any],
): ):
import json import json
...@@ -147,6 +148,7 @@ def save_snapshot( ...@@ -147,6 +148,7 @@ def save_snapshot(
"metric": task.metric_name, "metric": task.metric_name,
}, },
"model": {"name": model.__class__.__name__}, "model": {"name": model.__class__.__name__},
"run_options": run_options,
}, },
f, f,
indent=4, indent=4,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment