Commit 580b49d7 authored by Max Kimmich
Allow setting a random seed for reproducible experiments

parent d3db25c6
@@ -2,7 +2,7 @@ import logging
 from abc import ABCMeta, abstractmethod
 from typing import Any

-from numpy.random import choice
+from evoprompt.utils import get_rng
 from tqdm import trange

 from evoprompt.cli import argument_parser
@@ -27,7 +27,7 @@ Prompt 2: {prompt2}
 DE_COT_PROMPTS = [
-    "Step 1: Identify the different parts between the Prompt 1 and Prompt 2:\nPrompt 1: {prompt1}\nPrompt 2: {prompt2}",
+    "Step 1: Identify the main different parts between the Prompt 1 and Prompt 2:\nPrompt 1: {prompt1}\nPrompt 2: {prompt2}",
     "Step 2: Randomly mutate the different parts",
     "Step 3: Combine the different parts with Prompt 3, selectively replace it with the different parts in Step 2 and generate a new prompt.\nPrompt 3: {prompt3}",
     "Step 4: Cross over the prompt in the Step 3 with the following basic prompt and generate a final prompt bracketed with <prompt> and </prompt>:\nBasic Prompt: {basic_prompt}",
@@ -71,7 +71,9 @@ class EvolutionAlgorithm(PromptOptimization, metaclass=ABCMeta):
         # add small value to avoid zero chance of selection for some prompts
         scores = [prompt.score + 1e-6 for prompt in prompts]
         selection_probabilities = [score / sum(scores) for score in scores]
-        return choice(prompts, size=2, replace=False, p=selection_probabilities)
+        return get_rng().choice(
+            prompts, size=2, replace=False, p=selection_probabilities
+        )

     @abstractmethod
     def evolve(
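Note: the change above swaps numpy's module-level `choice` for the shared, seeded generator, so parent selection becomes reproducible. A minimal sketch of the fitness-proportionate draw, using a hypothetical `DummyPrompt` as a stand-in for the repo's Prompt type:

from dataclasses import dataclass

import numpy


@dataclass
class DummyPrompt:
    text: str
    score: float


# what init_rng(42) sets up as the process-wide generator
rng = numpy.random.default_rng(42)

prompts = [DummyPrompt("a", 0.1), DummyPrompt("b", 0.7), DummyPrompt("c", 0.2)]
# small epsilon keeps zero-score prompts selectable
scores = [p.score + 1e-6 for p in prompts]
probabilities = [s / sum(scores) for s in scores]
# two distinct parents, drawn proportionally to score, same result every run
parents = rng.choice(prompts, size=2, replace=False, p=probabilities)
print([p.text for p in parents])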
@@ -289,51 +291,38 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
         prompts_current_evolution: list[Prompt],
         current_iteration: int,
     ):
-        # TODO add description from paper
+        # TODO add description
         # DE needs best prompt for evolution
         best_prompt_current_evolution = max(
             prompts_current_evolution, key=lambda prompt: prompt.score
         )

-        response, messages, usage = self.evolution_model(
-            system_message=SYSTEM_MESSAGE,
-            prompt=DE_COT_PROMPTS[0].format(
-                prompt1=prompt_1,
-                prompt2=prompt_2,
-            ),
-        )
-        # input(messages)
-        # input(response)
-
-        response, messages, usage = self.evolution_model(
-            system_message=SYSTEM_MESSAGE,
-            prompt=DE_COT_PROMPTS[1],
-            history=messages,
-        )
-        # input(messages)
-        # input(response)
-
-        response, messages, usage = self.evolution_model(
-            system_message=SYSTEM_MESSAGE,
-            prompt=DE_COT_PROMPTS[2].format(
-                prompt3=best_prompt_current_evolution,
-                basic_prompt=prompts_current_evolution[current_iteration],
-            ),
-            history=messages,
-        )
-        # input(messages)
-        # input(response)
-
-        evolved_prompt, messages, usage = self.evolution_model(
-            system_message=SYSTEM_MESSAGE,
-            prompt=DE_COT_PROMPTS[3].format(
-                basic_prompt=prompts_current_evolution[current_iteration],
-            ),
-            history=messages,
-            stop="</prompt>",
-        )
-        # input(messages)
-        # input(evolved_prompt)
-
-        if "<prompt>" in evolved_prompt:
-            evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
+        messages = None
+        for idx, prompt in enumerate(DE_COT_PROMPTS):
+            response, messages, usage = self.evolution_model(
+                system_message=SYSTEM_MESSAGE,
+                prompt=prompt.format(
+                    prompt1=prompt_1,
+                    prompt2=prompt_2,
+                    prompt3=best_prompt_current_evolution,
+                    basic_prompt=prompts_current_evolution[current_iteration],
+                ),
+                history=messages,
+                stop="</prompt>" if idx == len(DE_COT_PROMPTS) - 1 else None,
+            )
+            logger.debug(
+                "Performed evolution (step %d) using DE-CoT:\n\tInputs: %s\n\tResponse: %s",
+                idx,
+                messages,
+                response,
+            )
+            # input(messages)
+            # input(response)
+
+        # at this point we should get a new prompt
+        if "<prompt>" in response:
+            response = response.split("<prompt>")[1].split("</prompt>")[0]

         logger.info(
             "DE-CoT-evolved prompts '%s', '%s' and '%s' with basic prompt '%s' into '%s'",
@@ -341,10 +330,10 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
             prompt_1,
             prompt_2,
             best_prompt_current_evolution,
             prompts_current_evolution[current_iteration],
-            evolved_prompt,
+            response,
         )

-        return evolved_prompt, usage
+        return response, usage


 def get_all_subclasses(cls):
......
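Note: the refactor above collapses four near-identical model calls into a single loop: the chat history is threaded through `history=messages`, and the `</prompt>` stop token is applied only on the last CoT step. A sketch of the pattern with a hypothetical `fake_model` mimicking the call shape of `self.evolution_model`:

def fake_model(system_message, prompt, history=None, stop=None):
    # returns (response, messages, usage) like the real evolution model
    history = (history or []) + [{"role": "user", "content": prompt}]
    response = f"<prompt>evolved via {prompt[:12]}</prompt>"
    return response, history, {"tokens": 0}


steps = ["Step 1: diff", "Step 2: mutate", "Step 3: combine", "Step 4: crossover"]
messages = None
for idx, step in enumerate(steps):
    response, messages, usage = fake_model(
        system_message="Please follow the instructions step by step.",
        prompt=step,
        history=messages,
        # stop only on the final step, which emits the bracketed prompt
        stop="</prompt>" if idx == len(steps) - 1 else None,
    )

# extract the evolved prompt from its <prompt>...</prompt> markers
if "<prompt>" in response:
    response = response.split("<prompt>")[1].split("</prompt>")[0]
print(response)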
@@ -11,6 +11,7 @@ import openai
 from evoprompt.cli import argument_parser
 from evoprompt.opt_types import ModelUsage
+from evoprompt.utils import get_seed

 logger = logging.getLogger(__name__)
@@ -79,6 +80,9 @@ class Llama(LLMModel):
         super().__init__(options)

         # initialize model
+        seed = get_seed()
+        if seed is not None:
+            kwargs["seed"] = seed
         if options.llama_path is not None:
             # use local file
             self.model = llama_cpp.Llama(
......
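Note: llama-cpp-python's `Llama` constructor accepts a `seed` argument, which is what makes forwarding the global seed work here. A hedged sketch of the conditional pass-through (`build_model` and the model path are illustrative, not from the repo):

import llama_cpp


def build_model(model_path: str, seed: int | None = None, **kwargs):
    # pin sampling only when a seed was configured; otherwise keep
    # llama.cpp's default behaviour
    if seed is not None:
        kwargs["seed"] = seed
    return llama_cpp.Llama(model_path=model_path, **kwargs)


# model = build_model("models/model.gguf", seed=42)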
@@ -155,7 +155,7 @@ class PromptOptimization:
         )
         logger.info(
             "Prompt '%s' scores %.2f %s.",
-            prompt.replace("\r", "\\r").replace("\n", "\\n"),
+            prompt,
             score,
             self.task.metric_name,
         )
......
@@ -300,7 +300,7 @@ class Task(metaclass=ABCMeta):
         self.test_dataset = self.load_test_set(test_dataset, test_split)
         if self.debug and len(self.test_dataset) > 5:
-            self.test_dataset = self.test_dataset.select(range(5))
+            self.test_dataset = self.test_dataset.shuffle(42).select(range(5))

         # cache evaluation runs to disk in a specific shelf for each model
         cache_path = Path(".cache_dir") / self.model.cache_key
......
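Note: the debug subset used to be the first five test examples; shuffling with a fixed seed keeps runs reproducible while avoiding the head-of-dataset bias. A small sketch with Hugging Face datasets (the toy data is illustrative):

from datasets import Dataset

ds = Dataset.from_dict({"text": [f"example {i}" for i in range(100)]})

head_subset = ds.select(range(5))               # old: always the first 5 rows
debug_subset = ds.shuffle(42).select(range(5))  # new: fixed random 5 rows

print(debug_subset["text"])  # identical on every run thanks to seed 42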
@@ -9,6 +9,12 @@ from textwrap import dedent, indent
 from typing import Any, Callable
 from uuid import uuid4

+import numpy
+import numpy.typing
+
+rng = None
+global_seed = None
+
 current_directory = Path(__file__).resolve().parent
 logger = logging.getLogger(__name__)
@@ -16,6 +22,26 @@ logger = logging.getLogger(__name__)
 file_handler = None


+def init_rng(seed: int | None):
+    global rng, global_seed
+    global_seed = seed
+    rng = numpy.random.default_rng(seed)
+    logger.info("Initialized rng with seed %s.", seed)
+
+
+def get_seed() -> int | None:
+    return global_seed
+
+
+def get_rng() -> numpy.random.Generator:
+    if rng is None:
+        raise RuntimeError(
+            "Rng was not initialized. Make sure to run `%s` first."
+            % (init_rng.__module__ + "." + init_rng.__qualname__)
+        )
+    return rng
+
+
 def setup_console_logger(verbosity_level: int = 0):
     # create console handler
     console_handler = logging.StreamHandler()
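Note: together these helpers form a process-wide RNG singleton: seed once at startup, then every consumer shares the same numpy Generator. Intended usage, as a sketch:

from evoprompt.utils import get_rng, get_seed, init_rng

init_rng(42)  # once, at program start (see main.py below)

print(get_seed())                         # 42, e.g. forwarded to llama.cpp
print(get_rng().integers(0, 10, size=3))  # reproducible draws everywhere
# calling get_rng() before init_rng() raises RuntimeError by design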
@@ -61,7 +87,7 @@ def initialize_run_directory(model: Callable):
     # create file handler and set level to debug
     file_handler = logging.FileHandler(run_directory / "output.log")
     file_handler.setLevel(logging.DEBUG)
-    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
+    formatter = logging.Formatter("{asctime} {levelname} {message}", style="{")
     file_handler.setFormatter(formatter)
     # add file handler to logger
     logger.addHandler(file_handler)
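Note: both formatter variants are valid stdlib logging; `style="{"` merely switches the format string from %-style to str.format-style placeholders. For illustration:

import logging

old_fmt = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
new_fmt = logging.Formatter("{asctime} {levelname} {message}", style="{")

record = logging.LogRecord("demo", logging.INFO, __file__, 1, "hello", None, None)
print(old_fmt.format(record))
print(new_fmt.format(record))  # same output, different placeholder syntax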
@@ -70,6 +96,10 @@ def initialize_run_directory(model: Callable):
     return run_directory


+def escape_whitespace(text: str) -> str:
+    return text.replace("\r", "\\r").replace("\n", "\\n")
+
+
 class log_calls:
     description: str
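Note: the new helper centralizes the `.replace(...)` chain that was dropped from the optimization logging above, so callers can flatten multi-line text onto a single log line when they need to:

from evoprompt.utils import escape_whitespace

print(escape_whitespace("Classify:\nText: {text}\r\n"))
# -> Classify:\nText: {text}\r\n   (printed as one line, escapes literal)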
@@ -141,7 +171,8 @@ class log_calls:
             arguments[argument_name] = value
         return arguments


 def get_all_subclasses(cls):
     return set(cls.__subclasses__()).union(
         [s for c in cls.__subclasses__() for s in get_all_subclasses(c)]
-    )
\ No newline at end of file
+    )
@@ -8,7 +8,7 @@ from evoprompt.cli import argument_parser
 from evoprompt.evolution import get_optimizer_class
 from evoprompt.models import Llama, LLMModel
 from evoprompt.task import get_task
-from evoprompt.utils import setup_console_logger
+from evoprompt.utils import init_rng, setup_console_logger

 logger = logging.getLogger(__name__)
@@ -28,12 +28,20 @@ def conv2bool(_str: Any):
 if __name__ == "__main__":
     # set additional CLI run arguments
     argument_parser.add_argument(
-        "-v", "--verbose", action="count", default=0, help="Increase "
+        "-v", "--verbose", action="count", default=0, help="Increase verbosity"
     )
+    argument_parser.add_argument(
+        "-s",
+        "--seed",
+        type=int,
+        default=42,
+        help="Set seed for random number generator (rng).",
+    )
     options = argument_parser.parse_args()

-    # set up console logging
+    # set up console logging and rng
     setup_console_logger(verbosity_level=options.verbose)
+    init_rng(options.seed)

     # log cli arguments
     logger.info(
......
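Note: end-to-end, the new flag ties the CLI to the RNG singleton. A hedged sketch of the wiring (`parse_args` is fed an explicit list so the snippet runs standalone; the real invocation would be e.g. `python main.py --seed 1234 -v`):

import argparse

from evoprompt.utils import init_rng

parser = argparse.ArgumentParser()
parser.add_argument("-v", "--verbose", action="count", default=0,
                    help="Increase verbosity")
parser.add_argument("-s", "--seed", type=int, default=42,
                    help="Set seed for random number generator (rng).")

options = parser.parse_args(["--seed", "1234"])  # simulates the CLI call
init_rng(options.seed)  # seed everything before the first random draw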