Commit 9cfca348 authored by Max Kimmich

Add task metric names and fix best prompt

parent bcefa7e4
@@ -267,15 +267,16 @@ def run_episode(evo_alg_str: str, debug: bool = False):
         S.append(population_scores)
         save_snapshot(family_tree, P, S, T, N)
 
-    save_genealogy(family_tree, P, S)
-    save_family_tree_visualization(family_tree, P, S)
+    save_genealogy(family_tree, P, metric_name=task.metric_name)
+    save_family_tree_visualization(family_tree, P, metric_name=task.metric_name)
     # Line 8: Return the best prompt, p∗, among the final population P_T:
     # p∗ ← argmax_{p ∈ P_T} f(p, D)
-    p = max(range(N), key=lambda i: population_scores[i])
-    logger.info(f"Best prompt: {population[p]}")
+    p = max(range(N), key=lambda i: S[-1][i])
+    logger.info(f"Best prompt: {P[-1][p]}")
 
     # We pick the prompt with the highest score on the development set and report its score on the test set.
-    task.evaluate_test(P[p])
+    test_performance = task.evaluate_test(P[-1][p])
+    logger.info(f"Best prompt on test set: {test_performance}")
 
 
 if __name__ == "__main__":
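Note on the second change in this hunk: `population_scores` and `population` were loop-local variables and stale at this point, whereas `S[-1]` and `P[-1]` are the scores and prompts of the final generation. A minimal sketch of the selection step, with toy values standing in for the real populations (the data is illustrative; the variable names follow the diff):

# Illustrative data: P holds one prompt list per generation, S the matching scores.
P = [["prompt a", "prompt b"], ["prompt a2", "prompt b2"]]
S = [[0.61, 0.58], [0.70, 0.66]]
N = 2  # population size

# p* <- argmax_{p in P_T} f(p, D): pick the index with the best final-generation score
p = max(range(N), key=lambda i: S[-1][i])
assert P[-1][p] == "prompt a2"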
@@ -50,6 +50,11 @@ class Task:
     def evaluate_test(self, prompt: str):
         return self._evaluate(prompt, self.test_dataset)
 
+    @property
+    @abstractmethod
+    def metric_name(self):
+        pass
+
     @property
     @abstractmethod
     def base_prompt(self):
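The new `metric_name` hook stacks `@property` on top of `@abstractmethod`, so every concrete `Task` must expose a metric label. A minimal sketch of that pattern with a hypothetical subclass (`ToyTask` is not part of the repo):

from abc import ABC, abstractmethod

class Task(ABC):
    @property
    @abstractmethod
    def metric_name(self):
        ...

class ToyTask(Task):  # hypothetical subclass, for illustration only
    @property
    def metric_name(self):
        return "accuracy"

print(ToyTask().metric_name)  # -> accuracy; instantiating Task() itself raises TypeError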
@@ -101,6 +106,10 @@ class SentimentAnalysis(Task):
         accuracy = results["correct"] / sum(results.values())
         return accuracy
 
+    @property
+    def metric_name(self):
+        return "accuracy"
+
     @property
     def base_prompt(self):
         # from the paper: RLPROMPT: Optimizing Discrete Text Prompts with Reinforcement Learning
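For reference, the `accuracy` context line divides the `correct` tally by the total number of predictions; assuming `results` is a Counter-like mapping from outcome labels to counts (the keys below are made up), the computation looks like:

from collections import Counter

results = Counter({"correct": 42, "incorrect": 8})  # hypothetical tallies
accuracy = results["correct"] / sum(results.values())
assert accuracy == 0.84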
@@ -167,6 +176,10 @@ class QuestionAnswering(Task):
         return f1 / num_samples  # , em / num_samples
 
+    @property
+    def metric_name(self):
+        return "f1"
+
     @property
     def base_prompt(self):
         # TODO find good prompt
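`QuestionAnswering` reports a mean per-sample F1 (`return f1 / num_samples`). The scoring code itself is outside this diff; the sketch below is a common token-overlap F1 in the style of the SQuAD evaluation script, offered as an assumption rather than the repo's actual implementation:

from collections import Counter

def token_f1(prediction: str, reference: str) -> float:
    """Token-level F1 between a predicted and a reference answer."""
    pred, ref = prediction.split(), reference.split()
    overlap = sum((Counter(pred) & Counter(ref)).values())  # multiset intersection
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(pred), overlap / len(ref)
    return 2 * precision * recall / (precision + recall)

samples = [("the cat sat", "the cat sat"), ("a dog", "the dog")]
f1 = sum(token_f1(p, r) for p, r in samples)
print(f1 / len(samples))  # (1.0 + 0.5) / 2 = 0.75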
@@ -121,7 +121,7 @@ class PromptEncoder(json.JSONEncoder):
 
 @log_calls("Saving genealogy")
-def save_genealogy(family_tree, P, S):
+def save_genealogy(family_tree, P, metric_name: str):
     node_id = lambda prompt: hash(prompt)
     dot = Digraph(comment="Genealogy")
     for t in range(len(P)):
@@ -135,7 +135,7 @@ def save_genealogy(family_tree, P, S):
                 continue
             cluster.node(
                 f"{node_id(prompt)}",
-                label=f"{fill(str(prompt))}\nAccuracy: {prompt.score}",
+                label=f"{fill(str(prompt))}\n{metric_name}: {prompt.score}",
             )
             if parents is not None:
                 parent_1, parent_2 = parents
@@ -145,7 +145,7 @@ def save_genealogy(family_tree, P, S):
     dot.render(run_directory / "genealogy", engine="dot")
 
-def save_family_tree_visualization(family_tree, P, S):
+def save_family_tree_visualization(family_tree, P, metric_name: str):
     dot = Digraph(comment="Genealogy")
     node_id = lambda prompt: hash(prompt)
     prompt_generations = lambda prompt: [
@@ -155,13 +155,14 @@ def save_family_tree_visualization(family_tree, P, S):
         """
         {prompt}
         ∈ P{{{prompt_generations}}}
-        Accuracy: {score}"""
+        {metric_name}: {score}"""
     )
     for prompt, parents in family_tree:
         node_text = node_template.format(
             prompt=fill(str(prompt)),
             prompt_generations=", ".join(prompt_generations(prompt)),
             score=prompt.score,
+            metric_name=metric_name,
         )
         dot.node(
             f"{node_id(prompt)}",
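Both helpers wrap the same `graphviz.Digraph` idiom, now labeling nodes with the task's metric instead of a hard-coded "Accuracy". A self-contained sketch of that idiom, assuming the python-graphviz package and a local Graphviz install (the prompts and scores are invented):

from textwrap import fill
from graphviz import Digraph

metric_name = "accuracy"
scores = {"Classify the sentiment.": 0.71, "Label the sentiment as good or bad.": 0.78}

dot = Digraph(comment="Genealogy")
for prompt, score in scores.items():
    # Node id is the prompt's hash; the label carries the metric, as in the diff.
    dot.node(f"{hash(prompt)}", label=f"{fill(prompt)}\n{metric_name}: {score}")
parent, child = scores  # first prompt parents the second, for illustration
dot.edge(f"{hash(parent)}", f"{hash(child)}")
dot.render("genealogy_example", engine="dot")  # writes the DOT source plus a PDF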