From 14c65c4d24894acd37905654897d4eb0dcbaa4d2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de>
Date: Mon, 18 Mar 2024 08:58:22 +0100
Subject: [PATCH] add model usage counter for individual prompts and models

---
 evo_types.py | 55 ++++++++++++++++++++++++++++++
 main.py      | 47 ++++++++++++++++----------
 models.py    | 94 +++++++++++++++++++++++++++++-----------------------
 task.py      | 48 ++++++++++++++-----------
 utils.py     | 39 +++++-----------------
 5 files changed, 171 insertions(+), 112 deletions(-)
 create mode 100644 evo_types.py

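Reviewer note (not part of the patch, ignored by git am): a minimal sketch of how
the new ModelUsage/Prompt types added in evo_types.py are meant to compose.
Usage objects accumulate via "+"/"+=" and are serialized together with prompts
through EvoTypeEncoder, as save_snapshot does. The prompt text and score below
are made-up example values.

    import json

    from evo_types import EvoTypeEncoder, ModelUsage, Prompt

    total = ModelUsage()
    total += ModelUsage(prompt_tokens=12, completion_tokens=30, total_tokens=42)
    total += ModelUsage(prompt_tokens=8, completion_tokens=20, total_tokens=28)
    assert total.total_tokens == 70

    # attach the accumulated usage to a prompt and dump both as JSON
    prompt = Prompt(content="Classify the sentiment.", score=0.91, gen=0, usage=total)
    print(json.dumps({"prompt": prompt, "usage": total}, cls=EvoTypeEncoder, indent=4))
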
diff --git a/evo_types.py b/evo_types.py
new file mode 100644
index 0000000..55b6a71
--- /dev/null
+++ b/evo_types.py
@@ -0,0 +1,55 @@
+import json
+from dataclasses import dataclass, field, is_dataclass
+from uuid import uuid4
+
+from llama_cpp.llama_types import CompletionUsage
+
+
+@dataclass(frozen=True)
+class ModelUsage:
+
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    total_tokens: int = 0
+
+    def __add__(self, other: "ModelUsage") -> "ModelUsage":
+        return ModelUsage(
+            prompt_tokens=self.prompt_tokens + other.prompt_tokens,
+            completion_tokens=self.completion_tokens + other.completion_tokens,
+            total_tokens=self.total_tokens + other.total_tokens,
+        )
+
+    def __sub__(self, other: "ModelUsage") -> "ModelUsage":
+        return ModelUsage(
+            prompt_tokens=self.prompt_tokens - other.prompt_tokens,
+            completion_tokens=self.completion_tokens - other.completion_tokens,
+            total_tokens=self.total_tokens - other.total_tokens,
+        )
+
+
+@dataclass(frozen=True)
+class Prompt:
+    content: str
+    score: float
+    gen: int
+    usage: ModelUsage
+    id: str = field(default_factory=lambda: uuid4().hex)
+    meta: dict = field(default_factory=dict)
+
+    def __str__(self) -> str:
+        return self.content
+
+    def __hash__(self) -> int:
+        return (
+            hash(self.content)
+            + hash(self.score)
+            + hash(self.gen)
+            + hash(frozenset(self.meta.items()))
+        )
+
+
+class EvoTypeEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if is_dataclass(obj):
+            return obj.__dict__
+        return json.JSONEncoder.default(self, obj)
diff --git a/main.py b/main.py
index 08a5b06..d4cd0e6 100644
--- a/main.py
+++ b/main.py
@@ -1,6 +1,5 @@
 import os
 from functools import lru_cache
-from pathlib import Path
 from typing import Any
 
 from dotenv import load_dotenv
@@ -8,9 +7,10 @@ from numpy.random import choice
 from tqdm import trange
 
 from cli import argument_parser
+from evo_types import ModelUsage, Prompt
 from models import Llama2, OpenAI
 from task import QuestionAnswering, SentimentAnalysis
-from utils import Prompt, initialize_run_directory, log_calls, logger, save_snapshot
+from utils import initialize_run_directory, log_calls, logger, save_snapshot
 
 
 def conv2bool(_str: Any):
@@ -25,8 +25,6 @@ def conv2bool(_str: Any):
 
 load_dotenv()
 
-current_directory = Path(__file__).resolve().parent
-
 PARAPHRASE_PROMPT = """
 Below is an instruction that describes a task. Write a response that paraphrases the instruction. Only output the paraphrased instruction bracketed in <prompt> and </prompt>.
 
@@ -40,13 +38,15 @@ Below is an instruction that describes a task. Write a response that paraphrases
 
 @log_calls("Paraphrasing prompts")
 def paraphrase_prompts(prompt: str, n: int):
+    total_usage = ModelUsage()
     paraphrases = []
     for _ in range(n):
-        paraphrase = evolution_model(
+        paraphrase, usage = evolution_model(
             prompt=PARAPHRASE_PROMPT.format(instruction=prompt)
         )
+        total_usage += usage
         paraphrases.append(paraphrase)
-    return paraphrases
+    return paraphrases, total_usage
 
 
 @log_calls("Performing selection")
@@ -100,18 +100,18 @@ def evolution_ga(prompt1: str, prompt2: str):
     #   in which random alterations are made to some of its content.
     # Based on this two-step process, we design instructions, guiding LLMs to
     # generate a new prompt based on these steps to perform Evo(·) in Algorithm 1.
-    evolved_prompt = evolution_model(
+    evolved_prompt, usage = evolution_model(
         prompt=GA_PROMPT.format(prompt1=prompt1, prompt2=prompt2)
     )
     if "<prompt>" in evolved_prompt:
         evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
-    return evolved_prompt
+    return evolved_prompt, usage
 
 
 @log_calls("Performing prompt evolution using DE")
 def evolution_de(prompt1: str, prompt2: str, basic_prompt: str, best_prompt: str):
     # TODO add comment from paper
-    evolved_prompt = evolution_model(
+    evolved_prompt, usage = evolution_model(
         prompt=DE_PROMPT.format(
             prompt1=prompt1,
             prompt2=prompt2,
@@ -121,7 +121,7 @@ def evolution_de(prompt1: str, prompt2: str, basic_prompt: str, best_prompt: str
     )
     if "<prompt>" in evolved_prompt:
         evolved_prompt = evolved_prompt.split("<prompt>")[1].split("</prompt>")[0]
-    return evolved_prompt
+    return evolved_prompt, usage
 
 
 @log_calls("Updating prompts")
@@ -146,13 +146,18 @@ def update(prompts: list[str], N: int):
 
 
 def run_episode(evo_alg_str: str, debug: bool = False):
+    # model usage for evolution of prompts
+    evolution_usage = ModelUsage()
+    # model usage for evaluating prompts
+    evaluation_usage = ModelUsage()
     # Algorithm 1 Discrete prompt optimization: EVOPROMPT
 
     # Require:
     # - Size of population
     N = 3 if debug else 10
     # - Initial prompts P0 = {p1, p2, . . . , pN }
-    paraphrases = paraphrase_prompts(task.base_prompt, n=N - 1)
+    paraphrases, usage = paraphrase_prompts(task.base_prompt, n=N - 1)
+    evolution_usage += usage
     # the initial population
     initial_population = [task.base_prompt] + paraphrases
 
@@ -166,15 +171,17 @@ def run_episode(evo_alg_str: str, debug: bool = False):
 
     # Line 1: Initial evaluation scores: S0 ← {si = fD (pi )|i ∈ [1, N ]}
     # the current population's scores
-    population_scores = [f_D(p) for p in initial_population]
+    initial_population_scores: list[tuple[float, ModelUsage]] = [f_D(p) for p in initial_population]
 
     # all_prompts contains a list of Prompt objects that took part in this run at some time
     # converting prompts to Prompt object
     all_prompts: dict[str, Prompt] = {
         prompt.id: prompt
         for prompt in [
-            Prompt(p, score=score, gen=0)
-            for (p, score) in zip(initial_population, population_scores)
+            Prompt(p, score=score, gen=0, usage=usage)
+            for (p, (score, usage)) in zip(
+                initial_population, initial_population_scores
+            )
         ]
     }
 
@@ -207,20 +214,22 @@ def run_episode(evo_alg_str: str, debug: bool = False):
             # Line 4: Evolution: generate a new prompt based on the selected parent prompts by leveraging LLM to perform evolutionary operators
             # p′i ←Evo(pr1,...,prk)
             if evo_alg_str == "ga":
-                p_i = evolution_ga(pr1, pr2)
+                p_i, usage = evolution_ga(pr1, pr2)
             elif evo_alg_str == "de":
-                p_i = evolution_de(
+                p_i, usage = evolution_de(
                     pr1,
                     pr2,
                     prompts_current_evolution[i],
                     best_prompt_current_evolution,
                 )
+            evolution_usage += usage
 
             # Line 5: Evaluation
             # s′_i ← f(p′i,D)
-            s_i = f_D(p_i)
+            s_i, usage = f_D(p_i)
+            evaluation_usage += usage
 
-            evolved_prompt = Prompt(content=p_i, score=s_i, gen=t)
+            evolved_prompt = Prompt(content=p_i, score=s_i, gen=t, usage=usage)
 
             # keep track of genealogy
             family_tree[evolved_prompt.id] = (pr1.id, pr2.id)
@@ -258,6 +267,8 @@ def run_episode(evo_alg_str: str, debug: bool = False):
         N,
         task,
         evolution_model,
+        evaluation_usage,
+        evolution_usage,
         options.__dict__,
     )
     # Line 8: Return the best prompt, p∗, among the final population PT :
diff --git a/models.py b/models.py
index 80e5868..b58eed5 100644
--- a/models.py
+++ b/models.py
@@ -4,10 +4,22 @@ from typing import Any
 import openai
 from llama_cpp import Llama
 
+from evo_types import ModelUsage
+
 current_directory = Path(__file__).resolve().parent
 
 
-class Llama2:
+class LLMModel:
+    chat: bool
+    model: Any
+
+    def __init__(self, chat: bool, model: Any):
+        self.usage = ModelUsage()
+        self.chat = chat
+        self.model = model
+
+
+class Llama2(LLMModel):
     """Loads and queries a Llama2 model."""
 
     def __init__(
@@ -20,12 +32,8 @@ class Llama2:
         verbose: bool = False,
         **kwargs
     ) -> None:
-        super().__init__()
-
-        self.chat = chat
-
         # initialize model
-        self.model = Llama(
+        model = Llama(
             model_path,
             chat_format="llama-2",
             verbose=verbose,
@@ -34,6 +42,7 @@ class Llama2:
             n_ctx=n_ctx,
             **kwargs,
         )
+        super().__init__(chat, model)
 
     def __call__(
         self,
@@ -42,12 +51,12 @@ class Llama2:
         stop: str = "</prompt>",
         max_tokens: int = 200,
         **kwargs: Any
-    ) -> str:
+    ) -> tuple[str, ModelUsage]:
         if chat is None:
             chat = self.chat
 
         if chat:
-            return self.model.create_chat_completion(
+            response = self.model.create_chat_completion(
                 # TODO add system message?
                 messages=[
                     {
@@ -58,23 +67,26 @@ class Llama2:
                 stop=stop,
                 max_tokens=max_tokens,
                 **kwargs,
-            )["choices"][0]["message"]["content"]
+            )
+            usage = ModelUsage(**response["usage"])
+            self.usage += usage
+            return response["choices"][0]["message"]["content"], usage
         else:
-            return self.model.create_completion(
+            response = self.model.create_completion(
                 prompt=prompt, stop=stop, max_tokens=max_tokens, **kwargs
-            )["choices"][0]["text"]
+            )
+            usage = ModelUsage(**response["usage"])
+            self.usage += usage
+            return response["choices"][0]["text"], usage
 
 
-class OpenAI:
+class OpenAI(LLMModel):
     """Queries an OpenAI model using its API."""
 
     def __init__(
         self, model: str, chat: bool = False, verbose: bool = False, **kwargs
     ) -> None:
-        super().__init__()
-
-        self.chat = chat
-        self.model = model
+        super().__init__(chat, model)
 
         # initialize model
         self.openai_client = openai.OpenAI(**kwargs)
@@ -86,36 +98,34 @@ class OpenAI:
         stop: str = "</prompt>",
         max_tokens: int = 200,
         **kwargs: Any
-    ) -> str:
+    ) -> tuple[str, ModelUsage]:
         if chat is None:
             chat = self.chat
 
         if chat:
-            return (
-                self.openai_client.chat.completions.create(
-                    model=self.model,
-                    messages=[
-                        {
-                            "role": "user",
-                            "content": prompt,
-                        }
-                    ],
-                    stop=stop,
-                    max_tokens=max_tokens,
-                    **kwargs,
-                )
-                .choices[0]
-                .message.content
+            response = self.openai_client.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": prompt,
+                    }
+                ],
+                stop=stop,
+                max_tokens=max_tokens,
+                **kwargs,
             )
+            usage = ModelUsage(**response.usage.__dict__)
+            self.usage += usage
+            return response.choices[0].message.content, usage
         else:
-            return (
-                self.openai_client.completions.create(
-                    model=self.model,
-                    prompt=prompt,
-                    stop=stop,
-                    max_tokens=max_tokens,
-                    **kwargs,
-                )
-                .choices[0]
-                .message.content
+            response = self.openai_client.completions.create(
+                model=self.model,
+                prompt=prompt,
+                stop=stop,
+                max_tokens=max_tokens,
+                **kwargs,
             )
+            usage = ModelUsage(**response.usage.__dict__)
+            self.usage += usage
+            return response.choices[0].text, usage
diff --git a/task.py b/task.py
index 27ddcfe..332ffb8 100644
--- a/task.py
+++ b/task.py
@@ -10,7 +10,7 @@ from llama_cpp import LlamaGrammar
 from tqdm import tqdm
 
 from models import Llama2, OpenAI
-from utils import log_calls, logger
+from utils import ModelUsage, log_calls, logger
 
 CLASSIFICATION_PROMPT = """
 Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
@@ -33,8 +33,8 @@ class Task:
         test_dataset: str,
         *,
         use_grammar: bool,
-        validation_split: str = None,
-        test_split: str = None,
+        validation_split: str | None = None,
+        test_split: str | None = None,
     ) -> None:
         self.model = model
 
@@ -47,11 +47,11 @@ class Task:
         self.test_dataset = load_dataset(test_dataset, split=test_split)
 
     @abstractmethod
-    def predict(self, prompt: str, *args, **kwargs):
+    def predict(self, prompt: str, *args, **kwargs) -> tuple[str, ModelUsage]:
         pass
 
     @abstractmethod
-    def _evaluate(self, prompt: str, dataset):
+    def _evaluate(self, prompt: str, dataset) -> tuple[float, ModelUsage]:
         pass
 
     @log_calls("Evaluating validation dataset")
@@ -95,8 +95,8 @@ class SentimentAnalysis(Task):
         test_dataset: str,
         *,
         use_grammar: bool,
-        validation_split: str = None,
-        test_split: str = None,
+        validation_split: str | None = None,
+        test_split: str | None = None,
     ) -> None:
         super().__init__(
             model,
@@ -110,7 +110,7 @@ class SentimentAnalysis(Task):
     def predict(self, prompt: str, text: str):
         # run model for inference using grammar to constrain output
         # TODO grammar also depends on prompt and vice-versa -> what are good labels?
-        response = self.model(
+        response, usage = self.model(
             prompt=CLASSIFICATION_PROMPT.format(instruction=prompt, input=text),
             grammar=sa_grammar_fn() if self.use_grammar else None,
             chat=False if self.use_grammar else True,
@@ -121,17 +121,19 @@ class SentimentAnalysis(Task):
             response = response.strip()
         # input(f"*** PREDICTION***\n\tText: {text}\n\tSentiment: {response}")
 
-        return response
+        return response, usage
 
     def _evaluate(self, prompt: str, dataset: Dataset):
         sst2_labels = {"negative": 0, "positive": 1}
 
         results: DefaultDict[str, int] = defaultdict(int)
         dataset_iterator = tqdm(dataset, desc="evaluating prompt", leave=False)
+        evaluation_usage = ModelUsage()
 
         for datum in dataset_iterator:
-            response = self.predict(prompt=prompt, text=datum["text"]).lower()
-
+            response, usage = self.predict(prompt=prompt, text=datum["text"])
+            response = response.lower()
+            evaluation_usage += usage
             if self.use_grammar:
                 # model output is from label space
                 answer_label = sst2_labels[response]
@@ -153,7 +155,7 @@ class SentimentAnalysis(Task):
             dataset_iterator.set_postfix(results)
 
         accuracy = results["correct"] / sum(results.values())
-        return accuracy
+        return accuracy, evaluation_usage
 
     @property
     def metric_name(self):
@@ -208,8 +210,8 @@ class QuestionAnswering(Task):
         test_dataset: str,
         *,
         use_grammar: bool,
-        validation_split: str = None,
-        test_split: str = None,
+        validation_split: str | None = None,
+        test_split: str | None = None,
     ) -> None:
         self.evaluation_prompt = """
         Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
@@ -239,7 +241,6 @@ class QuestionAnswering(Task):
 
     def predict(self, prompt: str, context: str, question: str):
         # run model for inference
-
         grammar = None
         if self.use_grammar:
             # context-sensitive grammar
@@ -253,7 +254,7 @@ class QuestionAnswering(Task):
                     exc_info=e,
                 )
 
-        response = self.model(
+        response, usage = self.model(
             prompt=self.evaluation_prompt.format(
                 instruction=prompt,
                 context=context,
@@ -268,9 +269,11 @@ class QuestionAnswering(Task):
             response = response.strip()
         # input(f"*** PREDICTION***\n\tContext: {context}\n\tQuestion: {question}\n\tAnswer: {response}")
 
-        return response
+        return response, usage
 
     def _evaluate(self, prompt: str, dataset: Dataset):
+        evaluation_usage = ModelUsage()
+
         def replace_symbol_for_grammar(sample: Mapping):
             sample["context"] = sample["context"].replace("–", "-")
             sample["answers"]["text"] = [
@@ -288,9 +291,12 @@ class QuestionAnswering(Task):
         f1 = 0.0
         em = 0
         for datum in dataset_iterator:
-            answer = self.predict(
-                prompt, context=datum["context"], question=datum["question"]
-            ).lower()
+            answer, usage = self.predict(
+                prompt, context=datum["context"], question=datum["question"]
+            )
+            # lowercase the prediction for a case-insensitive comparison
+            answer = answer.lower()
+            evaluation_usage += usage
 
             num_samples += 1
             result = self.metric.compute(
@@ -304,7 +310,7 @@ class QuestionAnswering(Task):
                 {"f1": f1 / num_samples, "em": em / num_samples}
             )
 
-        return f1 / num_samples
+        return f1 / num_samples, evaluation_usage
 
     @property
     def metric_name(self):
diff --git a/utils.py b/utils.py
index 128aa84..b1c9918 100644
--- a/utils.py
+++ b/utils.py
@@ -2,7 +2,6 @@ import inspect
 import json
 import logging
 import re
-from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
 from pprint import pformat
@@ -10,6 +9,7 @@ from textwrap import dedent, indent
 from typing import Any, Callable
 from uuid import uuid4
 
+from evo_types import EvoTypeEncoder, ModelUsage, Prompt
 from models import Llama2, OpenAI
 
 current_directory = Path(__file__).resolve().parent
@@ -23,7 +23,8 @@ Only return the name without any text before or after.""".strip()
 
 
 def initialize_run_directory(model: OpenAI | Llama2):
-    response = model(run_name_prompt, chat=True)
+    response, usage = model(run_name_prompt, chat=True)
+    model.usage -= usage  # exclude the run-name query from the tracked usage
     run_name_match = re.search(r"^\w+$", response, re.MULTILINE)
     if run_name_match is None:
         run_name = uuid4().hex
@@ -110,33 +111,6 @@ class log_calls:
         return arguments
 
 
-@dataclass(frozen=True)
-class Prompt:
-    content: str
-    score: float
-    gen: int
-    id: str = field(default_factory=lambda: uuid4().hex)
-    meta: dict = field(default_factory=dict)
-
-    def __str__(self) -> str:
-        return self.content
-
-    def __hash__(self) -> int:
-        return (
-            hash(self.content)
-            + hash(self.score)
-            + hash(self.gen)
-            + hash(frozenset(self.meta.items()))
-        )
-
-
-class PromptEncoder(json.JSONEncoder):
-    def default(self, obj):
-        if isinstance(obj, Prompt):
-            return obj.__dict__
-        return json.JSONEncoder.default(self, obj)
-
-
 def save_snapshot(
     run_directory: Path,
     all_prompts: list[Prompt],
@@ -146,9 +120,10 @@ def save_snapshot(
     N: int,
     task,
     model: Llama2 | OpenAI,
+    evaluation_usage: ModelUsage,
+    evolution_usage: ModelUsage,
     run_options: dict[str, Any],
 ):
-    import json
 
     with open(run_directory / "snapshot.json", "w") as f:
         json.dump(
@@ -166,11 +141,13 @@ def save_snapshot(
                     "use_grammar": task.use_grammar,
                 },
                 "model": {"name": model.__class__.__name__},
+                "evaluation_usage": evaluation_usage,
+                "evolution_usage": evolution_usage,
                 "run_options": run_options,
             },
             f,
             indent=4,
-            cls=PromptEncoder,
+            cls=EvoTypeEncoder,
         )
 
 
-- 
GitLab
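
Reviewer note (not part of the patch): with this change every model call returns
a (text, usage) pair and each model instance keeps a running self.usage total.
A rough sketch of the intended caller-side bookkeeping, mirroring run_episode;
the model path and constructor arguments below are placeholders, not taken from
the patch.

    from evo_types import ModelUsage
    from models import Llama2

    # hypothetical local model file; pass whatever Llama2 constructor args apply
    evolution_model = Llama2(model_path="path/to/llama-2.Q4_K_M.gguf", chat=True)

    evolution_usage = ModelUsage()
    for instruction in ["Paraphrase the input.", "Answer the question."]:
        text, usage = evolution_model(prompt=instruction)
        evolution_usage += usage  # per-episode accounting, as in run_episode

    # the model's own counter covers all of its calls, so it is at least as large
    assert evolution_model.usage.total_tokens >= evolution_usage.total_tokens
    print(evolution_usage)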