From 9cfca348494b66ddccd2bc07a747c43be73462cc Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <maximilian.schmidt@ims.uni-stuttgart.de>
Date: Tue, 6 Feb 2024 14:18:03 +0100
Subject: [PATCH] Add task metric names and fix best prompt

---
 main.py  | 11 ++++++-----
 task.py  | 13 +++++++++++++
 utils.py |  9 +++++----
 3 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/main.py b/main.py
index 66dd0a2..0812c2c 100644
--- a/main.py
+++ b/main.py
@@ -267,15 +267,16 @@ def run_episode(evo_alg_str: str, debug: bool = False):
         S.append(population_scores)
         save_snapshot(family_tree, P, S, T, N)
 
-    save_genealogy(family_tree, P, S)
-    save_family_tree_visualization(family_tree, P, S)
+    save_genealogy(family_tree, P, metric_name=task.metric_name)
+    save_family_tree_visualization(family_tree, P, metric_name=task.metric_name)
 
     # Line 8: Return the best prompt, p∗, among the final population PT :
     # p∗ ← argmaxp∈PT f(p, D)
-    p = max(range(N), key=lambda i: population_scores[i])
-    logger.info(f"Best prompt: {population[p]}")
+    p = max(range(N), key=lambda i: S[-1][i])
+    logger.info(f"Best prompt: {P[-1][p]}")
     # We pick the prompt with the highest score on the development set and report its score on the testset.
-    task.evaluate_test(P[p])
+    test_performance = task.evaluate_test(P[-1][p])
+    logger.info(f"Best prompt on test set: {test_performance}")
 
 
 if __name__ == "__main__":
diff --git a/task.py b/task.py
index 8737451..05ffc02 100644
--- a/task.py
+++ b/task.py
@@ -50,6 +50,11 @@ class Task:
     def evaluate_test(self, prompt: str):
         return self._evaluate(prompt, self.test_dataset)
 
+    @property
+    @abstractmethod
+    def metric_name(self):
+        pass
+
     @property
     @abstractmethod
     def base_prompt(self):
@@ -101,6 +106,10 @@ class SentimentAnalysis(Task):
         accuracy = results["correct"] / sum(results.values())
         return accuracy
 
+    @property
+    def metric_name(self):
+        return "accuracy"
+
     @property
     def base_prompt(self):
         # from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning
@@ -167,6 +176,10 @@ class QuestionAnswering(Task):
 
         return f1 / num_samples  # , em/num_samples
 
+    @property
+    def metric_name(self):
+        return "f1"
+
     @property
     def base_prompt(self):
         # TODO find good prompt
diff --git a/utils.py b/utils.py
index d568928..b30ed6e 100644
--- a/utils.py
+++ b/utils.py
@@ -121,7 +121,7 @@ class PromptEncoder(json.JSONEncoder):
 
 
 @log_calls("Saving genealogy")
-def save_genealogy(family_tree, P, S):
+def save_genealogy(family_tree, P, metric_name: str):
     node_id = lambda prompt: hash(prompt)
     dot = Digraph(comment="Genealogy")
     for t in range(len(P)):
@@ -135,7 +135,7 @@ def save_genealogy(family_tree, P, S):
                 continue
             cluster.node(
                 f"{node_id(prompt)}",
-                label=f"{fill(str(prompt))}\nAccuracy: {prompt.score}",
+                label=f"{fill(str(prompt))}\n{metric_name}: {prompt.score}",
             )
         if parents is not None:
             parent_1, parent_2 = parents
@@ -145,7 +145,7 @@ def save_genealogy(family_tree, P, S):
     dot.render(run_directory / "genealogy", engine="dot")
 
 
-def save_family_tree_visualization(family_tree, P, S):
+def save_family_tree_visualization(family_tree, P, metric_name: str):
     dot = Digraph(comment="Genealogy")
     node_id = lambda prompt: hash(prompt)
     prompt_generations = lambda prompt: [
@@ -155,13 +155,14 @@ def save_family_tree_visualization(family_tree, P, S):
         """
         {prompt}
         ∈ P{{{prompt_generations}}}
-        Accuracy: {score}"""
+        {metric_name}: {score}"""
     )
     for prompt, parents in family_tree:
         node_text = node_template.format(
             prompt=fill(str(prompt)),
             prompt_generations=", ".join(prompt_generations(prompt)),
             score=prompt.score,
+            metric_name=metric_name
         )
         dot.node(
             f"{node_id(prompt)}",
--
GitLab