diff --git a/main.py b/main.py
index 66dd0a23510adcf5d6ec4c4e69b3551b0bb486e2..0812c2ce046352876ac3ac7a582b98a28e4a3773 100644
--- a/main.py
+++ b/main.py
@@ -267,15 +267,16 @@ def run_episode(evo_alg_str: str, debug: bool = False):
         S.append(population_scores)
 
         save_snapshot(family_tree, P, S, T, N)
-    save_genealogy(family_tree, P, S)
-    save_family_tree_visualization(family_tree, P, S)
+    save_genealogy(family_tree, P, metric_name=task.metric_name)
+    save_family_tree_visualization(family_tree, P, metric_name=task.metric_name)
 
     # Line 8: Return the best prompt, p∗, among the final population P_T:
     # p∗ ← argmax_{p ∈ P_T} f(p, D)
-    p = max(range(N), key=lambda i: population_scores[i])
-    logger.info(f"Best prompt: {population[p]}")
+    p = max(range(N), key=lambda i: S[-1][i])
+    logger.info(f"Best prompt: {P[-1][p]}")
 
     # We pick the prompt with the highest score on the development set and report its score on the test set.
-    task.evaluate_test(P[p])
+    test_performance = task.evaluate_test(P[-1][p])
+    logger.info(f"Best prompt on test set: {test_performance}")
 
 if __name__ == "__main__":
diff --git a/task.py b/task.py
index 87374515ad5312a1bf29448589db8bdc5165165f..05ffc028bb5759ae810bd8b7cc2484c394c9ed52 100644
--- a/task.py
+++ b/task.py
@@ -50,6 +50,11 @@ class Task:
     def evaluate_test(self, prompt: str):
         return self._evaluate(prompt, self.test_dataset)
 
+    @property
+    @abstractmethod
+    def metric_name(self):
+        pass
+
     @property
     @abstractmethod
     def base_prompt(self):
@@ -101,6 +106,10 @@ class SentimentAnalysis(Task):
         accuracy = results["correct"] / sum(results.values())
         return accuracy
 
+    @property
+    def metric_name(self):
+        return "accuracy"
+
     @property
     def base_prompt(self):
         # from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning
@@ -167,6 +176,10 @@ class QuestionAnswering(Task):
 
         return f1 / num_samples  # , em/num_samples
 
+    @property
+    def metric_name(self):
+        return "f1"
+
     @property
     def base_prompt(self):
         # TODO find good prompt
diff --git a/utils.py b/utils.py
index d568928871812b561a985d5bc5dd08aeaa3af0d2..b30ed6ea18854b006d09dbbdadd73fd11da62b2c 100644
--- a/utils.py
+++ b/utils.py
@@ -121,7 +121,7 @@ class PromptEncoder(json.JSONEncoder):
 
 
 @log_calls("Saving genealogy")
-def save_genealogy(family_tree, P, S):
+def save_genealogy(family_tree, P, metric_name: str):
     node_id = lambda prompt: hash(prompt)
     dot = Digraph(comment="Genealogy")
     for t in range(len(P)):
@@ -135,7 +135,7 @@ def save_genealogy(family_tree, P, S):
                 continue
             cluster.node(
                 f"{node_id(prompt)}",
-                label=f"{fill(str(prompt))}\nAccuracy: {prompt.score}",
+                label=f"{fill(str(prompt))}\n{metric_name}: {prompt.score}",
             )
             if parents is not None:
                 parent_1, parent_2 = parents
@@ -145,7 +145,7 @@
     dot.render(run_directory / "genealogy", engine="dot")
 
 
-def save_family_tree_visualization(family_tree, P, S):
+def save_family_tree_visualization(family_tree, P, metric_name: str):
    dot = Digraph(comment="Genealogy")
    node_id = lambda prompt: hash(prompt)
    prompt_generations = lambda prompt: [
@@ -155,13 +155,14 @@
         """
         {prompt}
         ∈ P{{{prompt_generations}}}
-        Accuracy: {score}"""
+        {metric_name}: {score}"""
     )
     for prompt, parents in family_tree:
         node_text = node_template.format(
             prompt=fill(str(prompt)),
             prompt_generations=", ".join(prompt_generations(prompt)),
             score=prompt.score,
+            metric_name=metric_name,
         )
         dot.node(
             f"{node_id(prompt)}",