Commit 9976307b authored by Max Kimmich

Add optimization backend for HTTP API

parent aa579769
-from fastapi import FastAPI, Request
+from contextlib import asynccontextmanager
+
+from fastapi import BackgroundTasks, FastAPI, Request

-# from api.optimization import Backend
+from api.optimization import MultiProcessOptimizer
+
+# see https://github.com/tiangolo/fastapi/issues/3091#issuecomment-821522932 and
+# https://github.com/encode/starlette/issues/1094#issuecomment-730346075 for heavy-load computation

 DEBUG = True

-app = FastAPI(debug=DEBUG, title="Prompt Optimization Backend")
-# api = Backend(debug=DEBUG)
+backend = None
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    global backend
+    # Load the backend (which runs models in a separate process)
+    backend = MultiProcessOptimizer(debug=DEBUG)
+    with backend:
+        # add endpoints from backend
+        # TODO allow to get dynamically
+        actions = [("/action/evolve", backend.optimizer.evolve)]
+        for path, target in actions:
+            app.add_api_route(path, target)
+        app.openapi_schema = None
+        app.openapi()
+
+        yield
+
+        # Unload the backend, freeing resources used by the separate process
+        # (done automatically when the with block is exited)
+        print("Releasing resources")
+    # release remaining allocations
+    del backend
+    # TODO somehow not all resources are released upon uvicorn reload, need to investigate further

-def test():
-    pass
-
-
-# @app.get("/test")
-# async def test_long_operation(request: Request):
-#     loop = asyncio.get_event_loop()
-#     result = await loop.run_in_executor(pool, test)
-#     return "ok"
+app = FastAPI(debug=DEBUG, title="Prompt Optimization Backend", lifespan=lifespan)


 # start optimization
-@app.get("/run/{num_iterations}")
-async def run(num_iterations: int) -> str:
-    # api.run_optimization(num_iterations)
-    return "ok"
+@app.get("/run/start")
+def run(num_iterations: int, background_tasks: BackgroundTasks) -> str:
+    background_tasks.add_task(backend.run_optimization, num_iterations)
+    return f"Running optimization with {num_iterations} iterations."
+
+
+# get progress
+@app.get("/run/progress")
+def get_run_progress() -> str:
+    result = backend.get_progress()
+    return result


-# TODO turn actions into router and allow to set actions dynamically
-# perform optimizer-specific action
-@app.get("/action/evolve/")
-async def evolve(prompt1: str, prompt2: str) -> str:
-    return f"This is the evolved prompt taking prompts {prompt1} and {prompt2} into account."
+# get run state
+@app.get("/run/status")
+async def get_run_status() -> bool:
+    return backend._running


 # get current genealogy of prompts
 @app.get("/family_tree/get")
 async def get_family() -> dict:
-    return dict()
+    return backend.optimizer.family_tree


 @app.get("/")
...
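For orientation, here is a minimal sketch of how the reworked endpoints could be exercised from a client. The module path `api.main`, the port, and the use of httpx are assumptions (any HTTP client works); the routes and parameters are exactly the ones registered above.

# hypothetical client; assumes the app is served via: uvicorn api.main:app --port 8000
import httpx

BASE = "http://localhost:8000"

# kick off a background run (handled via BackgroundTasks, so this returns immediately)
print(httpx.get(f"{BASE}/run/start", params={"num_iterations": 10}).text)

# poll the tqdm-rendered progress string and the running flag
print(httpx.get(f"{BASE}/run/progress").text)
print(httpx.get(f"{BASE}/run/status").json())

# fetch the current prompt genealogy
print(httpx.get(f"{BASE}/family_tree/get").json())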
from concurrent.futures import ProcessPoolExecutor

from evolution import GeneticAlgorithm
from models import Llama2, LLMModel
from task import SentimentAnalysis

# def create_model():
#     global optimizer
#     optimizer = Optimizer(debug=DEBUG)

_evolution_model: LLMModel = None
_evaluation_model: LLMModel = None
# _model_call_type: get_type_hints(LLMModel).get("__call__")


def _setup_models() -> None:
    global _evolution_model, _evaluation_model

    if _evolution_model is not None:
        raise Exception("Evolution model has already been initialized.")
    # currently a fixed model
    _evolution_model = Llama2(
        model_path="./models/llama-2-13b-chat.Q5_K_M.gguf",
        chat=True,
    )

    if _evaluation_model is not None:
        raise Exception("Evaluation model has already been initialized.")
    # currently a fixed model
    _evaluation_model = _evolution_model


def _release_models() -> None:
    global _evolution_model, _evaluation_model
    del _evolution_model
    del _evaluation_model


def _call_evolution_model(*args, **kwargs):
    return _evolution_model(*args, **kwargs)


def _call_evaluation_model(*args, **kwargs):
    return _evaluation_model(*args, **kwargs)


def f():
    pass


class MultiProcessOptimizer:
    _instance: "MultiProcessOptimizer" = None
    # a flag indicating whether the optimizer is currently running
    _running: bool = False
    model_exec: ProcessPoolExecutor = None

    def __new__(cls, *args, **kwargs):
        # only allow one instance to be created (singleton pattern)
        if cls._instance is None:
            cls._instance = super(MultiProcessOptimizer, cls).__new__(cls)
        return cls._instance

    def __init__(self, *, debug: bool = False) -> None:
        self.debug = debug

    def __enter__(self):
        # TODO allow to customize optimizer
        # create necessary models
        # initialize worker processes; only 1 worker since prediction is memory-intensive
        # since we only have 1 worker we just save the state in the global namespace which the single worker accesses
        self.model_exec = ProcessPoolExecutor(max_workers=1, initializer=_setup_models)
        # make sure that the initializer is called
        self.model_exec.submit(f).result()
        evolution_model = lambda *args, **kwargs: self.model_exec.submit(
            _call_evolution_model, *args, **kwargs
        ).result()
        evaluation_model = lambda *args, **kwargs: self.model_exec.submit(
            _call_evaluation_model, *args, **kwargs
        ).result()

        # currently a fixed task
        task = SentimentAnalysis(
            evaluation_model,
            "SetFit/sst2",
            "SetFit/sst2",
            use_grammar=False,
            validation_split=f"validation[:{5 if self.debug else 200}]",
            test_split="test[:20]" if self.debug else "test",
        )

        optimizer_class = GeneticAlgorithm
        # optimizer_class = DifferentialEvolution
        self.optimizer = optimizer_class(
            population_size=10,
            task=task,
            evolution_model=evolution_model,
            evaluation_model=evaluation_model,
        )
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        print("Shutting down")
        self._submit(_release_models).result()
        self.model_exec.shutdown(wait=False)
        self.model_exec = None

    def _submit(self, fn, *fn_args, **fn_kwargs):
        if self.model_exec is None:
            raise RuntimeError(
                "Cannot access model executor - you have to use this class as a context manager with the with statement first."
            )
        return self.model_exec.submit(fn, *fn_args, **fn_kwargs)

    def run_optimization(self, num_iterations: int) -> None:
        self._running = True
        self.optimizer.run(num_iterations, debug=self.debug, add_snapshot_dict={})
        self._running = False

    def get_progress(self):
        if hasattr(self.optimizer, "iterations_pbar"):
            result = str(self.optimizer.iterations_pbar)
        else:
            result = "Optimization has not run yet."
        return result
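As a usage note, the class is meant to be driven as a context manager; a minimal sketch mirroring what the FastAPI lifespan hook does (the iteration count is arbitrary). Since there is only one worker, module-level globals in the worker process act as its private state, which is why `_setup_models` can stash the models at module scope.

backend = MultiProcessOptimizer(debug=True)
with backend:
    # models live in the single worker process; calls are proxied via the executor
    backend.run_optimization(num_iterations=2)
    print(backend.get_progress())
# exiting the with block releases the models and shuts the worker down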
@@ -112,7 +112,10 @@ class EvolutionAlgorithm(PromptOptimization):
         P = [initial_prompts]

         # Line 2:
-        for t in trange(1, num_iterations + 1, desc="iterations", leave=True):
+        self.iterations_pbar = trange(
+            1, num_iterations + 1, desc="iterations", leave=True
+        )
+        for t in self.iterations_pbar:
             # Line 3: Selection: select a certain number of prompts from current population as parent prompts
             # pr1,...,prk ∼ Pt−1
             prompts_current_evolution = P[t - 1]
@@ -181,15 +184,15 @@ class EvolutionAlgorithm(PromptOptimization):
 class GeneticAlgorithm(EvolutionAlgorithm):
     """The genetic algorithm implemented using LLMs."""

+    # kwargs is just there for convenience, as evolve function of other optimizers might have different inputs
+    # @register_action(ignore_args=["kwargs"])
     @log_calls("Performing prompt evolution using GA")
     def evolve(
         self,
         prompt_1: str,
         prompt_2: str,
-        *,
-        prompts_current_evolution: list[Prompt],
-        current_iteration: int,
-    ):
+        **kwargs,
+    ) -> tuple[str, ModelUsage]:
         # Following the evolutionary operators in GA, a new candidate prompt is generated through
         # a two-step process based on the selected two parents:
         # 1) The parent prompts undergo crossover, resulting in a new prompt that
...
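The point of keeping the `trange` handle on `self` is that tqdm progress bars render their current state via `str()`, which is exactly what `get_progress` forwards over HTTP. A standalone sketch of that behavior:

from tqdm import trange

pbar = trange(1, 11, desc="iterations", leave=True)
for t in pbar:
    pass  # one evolution step would run here
# str() on a tqdm bar yields the formatted meter, e.g.
# "iterations: 100%|██████████| 10/10 [00:00<00:00, ...]"
print(str(pbar))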
@@ -39,7 +39,6 @@ class LLMModel:
     model: Any

     def __init__(self, chat: bool, model: Any):
-        self.usage = ModelUsage()
         self.chat = chat
         self.model = model
@@ -139,7 +138,6 @@ class Llama2(LLMModel):
         response_text = response["choices"][0]["text"]
         # input(f"Response: {response_text}")
         usage = ModelUsage(**response["usage"])
-        self.usage += usage
         return response_text, usage
@@ -198,7 +196,6 @@ class OpenAI(LLMModel):
                 **kwargs,
             )
             usage = ModelUsage(**response.usage.__dict__)
-            self.usage += usage
             return response.choices[0].message.content, usage
         else:
             response = self.openai_client.completions.create(
@@ -213,5 +210,4 @@ class OpenAI(LLMModel):
                 **kwargs,
             )
             usage = ModelUsage(**response.usage.__dict__)
-            self.usage += usage
             return response.choices[0].text, usage
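These hunks drop the accumulating `self.usage` counter, so usage is now only returned per call and aggregation moves to the call site. A sketch of what that looks like for a caller, assuming `ModelUsage()` starts at zero and supports `+=` as the removed lines imply:

total = ModelUsage()  # zero-argument constructor, as in the removed LLMModel.__init__ line
for prompt in ("prompt A", "prompt B"):  # hypothetical prompts
    response, usage = model(None, prompt, chat=True)  # call signature as used in initialize_run_directory
    total += usage  # per-call usage, summed only where a total is actually needed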
@@ -34,7 +34,6 @@ RUNS_DIR = current_directory / "runs"
 def initialize_run_directory(model: Callable):
     response, usage = model(None, run_name_prompt, chat=True)
-    model.usage -= usage
     run_name_match = re.search(r"^\w+$", response, re.MULTILINE)
     existing_run_names = os.listdir(RUNS_DIR)
     if run_name_match is None or run_name_match.group(0) in existing_run_names:
...
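For reference, the `^\w+$` pattern with `re.MULTILINE` only accepts a response line that is a single bare word, which is what makes the collision check against existing run directories meaningful. A small sketch with a hypothetical model response:

import re

response = "Sure! Here is a short name:\nsunny_run\n"  # hypothetical model output
match = re.search(r"^\w+$", response, re.MULTILINE)
print(match.group(0) if match else "no usable name")  # -> sunny_run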