Skip to content
Snippets Groups Projects
Commit eaa1f0d5 authored by Grießhaber Daniel's avatar Grießhaber Daniel :squid:
Browse files

add extra metadata to snapshot

parent 4f666b9f
No related branches found
No related tags found
No related merge requests found
@@ -135,7 +135,7 @@ def update(prompts: list[str], N: int):
     # Specifically, at each iteration, EVOPROMPT based on GA produces N new prompts,
     # which are combined with the current population of N prompts.
     # The updated population is then selected by retaining the N prompts with the highest scores.
-    retained_prompts = []
+    retained_prompts: list[Prompt] = []
     min_retained_score = 0
     for prompt in prompts:
         if len(retained_prompts) < N:
@@ -202,7 +202,6 @@ def run_episode(evo_alg_str: str, debug: bool = False):
     )
     new_evolutions = []
-    new_evolutions_scores = []
     for i in trange(N, desc="N", leave=False):
         # for both GA and DE we start with two parent prompts
@@ -253,7 +252,7 @@ def run_episode(evo_alg_str: str, debug: bool = False):
         # store new generation
         P.append([prompt.id for prompt in population])
-        save_snapshot(all_prompts, family_tree, P, T, N)
+        save_snapshot(all_prompts, family_tree, P, T, N, task, evolution_model)
     # Line 8: Return the best prompt, p∗, among the final population PT :
     # p∗ ← argmaxp∈PT f(p, D)
     p = all_prompts[max(P[-1], key=lambda prompt_id: all_prompts[prompt_id].score)]
...
-from abc import abstractmethod
 from pathlib import Path
 from typing import Any
-from llama_cpp import Llama
 import openai
+from llama_cpp import Llama
-from utils import log_calls
 current_directory = Path(__file__).resolve().parent
@@ -39,11 +35,10 @@ class Llama2:
         **kwargs,
     )
-    # @log_calls("Running Llama model")
     def __call__(
         self,
         prompt: str,
-        chat: bool = None,
+        chat: bool | None = None,
         stop: str = "</prompt>",
         max_tokens: int = 200,
         **kwargs: Any
...
@@ -10,7 +10,7 @@ from textwrap import dedent, indent
 from typing import Any, Callable
 from uuid import uuid4
-from graphviz import Digraph
+from models import Llama2, OpenAI
 current_directory = Path(__file__).resolve().parent
 run_directory = (
@@ -121,7 +121,15 @@ class PromptEncoder(json.JSONEncoder):
         return json.JSONEncoder.default(self, obj)
-def save_snapshot(all_prompts, family_tree, P, T, N):
+def save_snapshot(
+    all_prompts: list[Prompt],
+    family_tree: dict[str, tuple[str, str] | None],
+    P: list[list[str]],
+    T: int,
+    N: int,
+    task,
+    model: Llama2 | OpenAI,
+):
     import json
     with open(run_directory / "snapshot.json", "w") as f:
@@ -132,6 +140,13 @@ def save_snapshot(all_prompts, family_tree, P, T, N):
             "P": P,
             "T": T,
             "N": N,
+            "task": {
+                "name": task.__class__.__name__,
+                "validation_dataset": task.validation_dataset.info.dataset_name,
+                "test_dataset": task.test_dataset.info.dataset_name,
+                "metric": task.metric_name,
+            },
+            "model": {"name": model.__class__.__name__},
         },
         f,
         indent=4,
...
0% — Loading, or an error occurred.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment