diff --git a/main.py b/main.py
index fe8904cf423600480af9b6885bcf006f47148dcc..063f45bfe6e8426fafc52535fde5f36bfe6504c1 100644
--- a/main.py
+++ b/main.py
@@ -135,7 +135,7 @@ def update(prompts: list[str], N: int):
     # Specifically, at each iteration, EVOPROMPT based on GA produces N new prompts,
     # which are combined with the current population of N prompts.
     # The updated population is then selected by retaining the N prompts with the highest scores.
-    retained_prompts = []
+    retained_prompts: list[Prompt] = []
     min_retained_score = 0
     for prompt in prompts:
         if len(retained_prompts) < N:
@@ -202,7 +202,6 @@ def run_episode(evo_alg_str: str, debug: bool = False):
     )
 
     new_evolutions = []
-    new_evolutions_scores = []
 
     for i in trange(N, desc="N", leave=False):
         # for both GA and DE we start with two parent prompts
@@ -253,7 +252,7 @@
         # store new generation
         P.append([prompt.id for prompt in population])
 
-    save_snapshot(all_prompts, family_tree, P, T, N)
+    save_snapshot(all_prompts, family_tree, P, T, N, task, evolution_model)
     # Line 8: Return the best prompt, p∗, among the final population PT :
     # p∗ ← argmaxp∈PT f(p, D)
     p = all_prompts[max(P[-1], key=lambda prompt_id: all_prompts[prompt_id].score)]
diff --git a/models.py b/models.py
index a640c53e7d49ccd576451eec76ce2595450c604c..80e586806d745525c149133c2317806226b8dd6a 100644
--- a/models.py
+++ b/models.py
@@ -1,12 +1,8 @@
-from abc import abstractmethod
 from pathlib import Path
 from typing import Any
 
-from llama_cpp import Llama
 import openai
-
-from utils import log_calls
-
+from llama_cpp import Llama
 
 
 current_directory = Path(__file__).resolve().parent
@@ -39,11 +35,10 @@ class Llama2:
             **kwargs,
         )
 
-    # @log_calls("Running Llama model")
     def __call__(
         self,
         prompt: str,
-        chat: bool = None,
+        chat: bool | None = None,
         stop: str = "</prompt>",
         max_tokens: int = 200,
         **kwargs: Any
diff --git a/utils.py b/utils.py
index 8ea4d661a664fd4626da7c8390b5828aa5f350c8..1039e496be48968d70a7e383ab5a9a9c0a45657a 100644
--- a/utils.py
+++ b/utils.py
@@ -10,7 +10,7 @@ from textwrap import dedent, indent
 from typing import Any, Callable
 from uuid import uuid4
 
-from graphviz import Digraph
+from models import Llama2, OpenAI
 
 current_directory = Path(__file__).resolve().parent
 run_directory = (
@@ -121,7 +121,15 @@ class PromptEncoder(json.JSONEncoder):
         return json.JSONEncoder.default(self, obj)
 
 
-def save_snapshot(all_prompts, family_tree, P, T, N):
+def save_snapshot(
+    all_prompts: list[Prompt],
+    family_tree: dict[str, tuple[str, str] | None],
+    P: list[list[str]],
+    T: int,
+    N: int,
+    task,
+    model: Llama2 | OpenAI,
+):
     import json
 
     with open(run_directory / "snapshot.json", "w") as f:
@@ -132,6 +140,13 @@
                 "P": P,
                 "T": T,
                 "N": N,
+                "task": {
+                    "name": task.__class__.__name__,
+                    "validation_dataset": task.validation_dataset.info.dataset_name,
+                    "test_dataset": task.test_dataset.info.dataset_name,
+                    "metric": task.metric_name,
+                },
+                "model": {"name": model.__class__.__name__},
             },
             f,
             indent=4,
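
Note on the update() hunk above: the docstring describes GA survivor selection, where the N newly produced prompts are merged with the current population and only the N highest-scoring prompts survive. A minimal standalone sketch of that rule, assuming a Prompt with a numeric score attribute as in main.py (select_survivors is a hypothetical helper, not part of this diff):

from dataclasses import dataclass


@dataclass
class Prompt:
    text: str
    score: float


def select_survivors(population: list[Prompt], offspring: list[Prompt], N: int) -> list[Prompt]:
    # Merge the current population with the offspring, then keep the N best
    # by score, matching the selection rule described in update()'s docstring.
    return sorted(population + offspring, key=lambda p: p.score, reverse=True)[:N]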
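
The new snapshot metadata in utils.py assumes the task object exposes Hugging Face datasets (so validation_dataset.info.dataset_name resolves) plus a metric_name attribute. A hypothetical task satisfying that interface; the class name, dataset, and split names are illustrative only:

from datasets import load_dataset


class SentimentAnalysis:
    # Illustrative task object: any class with these attributes would satisfy
    # the interface save_snapshot() reads its metadata from.
    metric_name = "accuracy"

    def __init__(self):
        # load_dataset returns Dataset objects whose .info.dataset_name is
        # what ends up in snapshot.json.
        self.validation_dataset = load_dataset("sst2", split="validation")
        self.test_dataset = load_dataset("sst2", split="test")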