Grießhaber Daniel / evoprompt / Commits / b77a62b7

Commit b77a62b7 authored 11 months ago by Grießhaber Daniel
added parent based early stopping monitor
parent 28441f6d
Showing 4 changed files with 108 additions and 27 deletions:

evolution.py       1 addition, 1 deletion
opt_types.py       1 addition, 0 deletions
optimization.py    22 additions, 8 deletions
task.py            84 additions, 18 deletions
evolution.py  +1 −1

@@ -169,7 +169,7 @@ class EvolutionAlgorithm(PromptOptimization):
         logger.info(f"Best prompt: {p}")
         # We pick the prompt with the highest score on the development set and report its score on the testset.
-        test_performance, _ = self.task.evaluate_test(p.content)
+        test_performance, _, _ = self.task.evaluate_test(p.content)
         logger.info("Best prompt on test set: %s", test_performance)
         logger.info(
             "Usage (evolution model / evaluation model / total): %s / %s / %s",
...
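Note on this change: Task.evaluate now returns a third element (the per-sample evaluation history), so the call site in the evolution loop unpacks and discards one more value. A minimal sketch of the call after this commit; the variable name best_prompt is assumed for illustration (the diff uses p):

# Hypothetical call site: only the score is used here, the usage and the
# per-sample history returned alongside it are ignored.
test_performance, _usage, _history = self.task.evaluate_test(best_prompt.content)
logger.info("Best prompt on test set: %s", test_performance)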
opt_types.py  +1 −0

@@ -30,6 +30,7 @@ class Prompt:
     content: str
     score: float
     usage: ModelUsage
+    evaluation_history: list[float]
     meta: dict = field(default_factory=dict)
     id: str = field(default_factory=lambda: uuid4().hex)
...
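The new evaluation_history field records the running metric values collected while a prompt is evaluated; children later receive their parents' histories for the parent-based early stopping. A rough sketch of how the field gets populated, with local names chosen for illustration rather than taken from the commit:

# Sketch only: Task.evaluate_validation now returns (score, usage, history),
# and the history is stored on the Prompt for use by its future children.
score, usage, history = task.evaluate_validation(content, None)
prompt = Prompt(content=content, score=score, usage=usage, evaluation_history=history)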
optimization.py  +22 −8

@@ -58,22 +58,36 @@ class PromptOptimization:
     def reset(self):
         self._init

-    def evaluate_prompt(self, prompt: str):
-        return self.task.evaluate_validation(prompt)
+    def evaluate_prompt(self, prompt: str, parents: tuple[Prompt] | None = None):
+        parent_histories = (
+            [parent.evaluation_history for parent in parents]
+            if parents is not None
+            else None
+        )
+        return self.task.evaluate_validation(prompt, parent_histories)

     def add_prompt(
-        self, prompt: str, parents: tuple[Prompt] = None, meta: dict = None
+        self,
+        prompt: str,
+        parents: tuple[Prompt] | None = None,
+        meta: dict | None = None,
     ) -> Prompt:
-        score, usage = self.evaluate_prompt(prompt)
-        prompt = Prompt(content=prompt, score=score, meta=meta, usage=usage)
+        score, usage, history = self.evaluate_prompt(prompt, parents)
+        prompt_object = Prompt(
+            content=prompt,
+            score=score,
+            meta=meta if meta is not None else {},
+            usage=usage,
+            evaluation_history=history,
+        )

         # keep track of prompt
-        self.all_prompts[prompt.id] = prompt
-        self.family_tree[prompt.id] = (
+        self.all_prompts[prompt_object.id] = prompt_object
+        self.family_tree[prompt_object.id] = (
             tuple(p.id for p in parents) if parents is not None else None
         )

-        return prompt
+        return prompt_object

     def add_prompts(
         self,
...
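Taken together, add_prompt now threads the parents through evaluate_prompt, which collects their stored evaluation_history lists and passes them to Task.evaluate_validation; the returned history is then stored on the newly created Prompt. A hedged usage sketch; the optimizer instance, prompt text, parent variables, and meta keys are assumptions made for illustration:

# Illustrative only: evaluating a child prompt against its two parents so the
# task can compare the child's running score with the parents' histories.
child = optimizer.add_prompt(
    "Classify the sentiment of the following review.",
    parents=(parent_a, parent_b),
    meta={"generation": 3},
)
print(child.score, len(child.evaluation_history))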
task.py  +84 −18

@@ -23,7 +23,14 @@ You are given an instruction that describes a task, paired with an input that pr
 DatasetDatum = dict


-class MomentBasedStopping:
+class EarlyStoppingMonitor:
+
+    @abstractmethod
+    def update(self, score: float) -> bool:
+        raise NotImplementedError
+
+
+class MomentBasedStopping(EarlyStoppingMonitor):
     """
     Watch the first derivative (moment) of the metric to determine when to stop.
     """
...
@@ -45,18 +52,60 @@ class MomentBasedStopping:
     def update(self, score: float) -> bool:
         # caclulate the current moment (dx/dt)
         self.num_calls += 1
+        if self.num_calls < self.start_after:
+            return False
+
         self.moment_magnitudes.append(abs(score - self.last_score))
         self.last_score = score
+        if len(self.moment_magnitudes) < self.patience:
+            return False
+
+        if mean(self.moment_magnitudes) < self.min_moment_magnitude:
+            return True
+
+        return False
+
+
+class ParentBaselineBasedStopping(EarlyStoppingMonitor):
+
+    def __init__(
+        self,
+        parent_histories: list[list[float]],
+        *,
+        patience: int = 10,
+        start_after: int = 20,
+        min_improvement: float = 0.001,
+    ):
+        self.parent_histories = parent_histories
+        self.patience = patience
+        self.start_after = start_after
+        self.min_improvement = min_improvement
+
+        self.num_calls = 0
+        self.improvement_memory = deque(maxlen=patience)
+
+    def update(self, score: float) -> bool:
+        self.num_calls += 1
+        if self.num_calls < self.start_after:
+            return False
-        if (
-            self.num_calls < self.start_after
-            or len(self.moment_magnitudes) < self.patience
-        ):
+
+        parent_values = [
+            # get the metric value of the parents at the current step
+            (
+                parent_history[self.num_calls - 1]
+                if len(parent_history) >= self.num_calls
+                else parent_history[-1]  # extend with last value
+            )
+            for parent_history in self.parent_histories
+        ]
+        self.improvement_memory.append(
+            score - max(parent_values)  # compare with the best parent
+        )
+
+        if len(self.improvement_memory) < self.patience:
             return False
-        print(mean(self.moment_magnitudes))
-        if mean(self.moment_magnitudes) < self.min_moment_magnitude:
+
+        if max(self.improvement_memory) < self.min_improvement:
+            # if the highest improvement is less than the minimum improvement, we stop
             return True

         return False
...
@@ -101,16 +150,31 @@ class Task:
     def _aggregate_result(self, results: list) -> float:
         pass

-    def evaluate(self, prompt: str, dataset: Dataset) -> tuple[float, ModelUsage]:
-        early_stopping = MomentBasedStopping(
-            patience=len(dataset) // 20,
-            start_after=len(dataset) // 5,
-        )
+    def evaluate(
+        self,
+        prompt: str,
+        dataset: Dataset,
+        parent_histories: list[list[float]] | None = None,
+    ) -> tuple[float, ModelUsage, list[float]]:
+        early_stopping: EarlyStoppingMonitor
+        early_stopping_params = {
+            "patience": max(len(dataset) // 20, 5),
+            "start_after": max(len(dataset) // 5, 5),
+        }
+        if parent_histories is not None:
+            early_stopping = ParentBaselineBasedStopping(
+                parent_histories, **early_stopping_params
+            )
+        else:
+            early_stopping = MomentBasedStopping(**early_stopping_params)
+
         results: list = []
         dataset_iterator: tqdm[DatasetDatum] = tqdm(
             dataset, desc="evaluating prompt", leave=False
         )
         evaluation_usage = ModelUsage()
+        evaluation_history = []

         for datum in dataset_iterator:
             result, usage = self._evaluate_sample(prompt, datum)
...
@@ -120,18 +184,20 @@ class Task:
                 {self.metric_name: f"{current_metric*100:.1f}%"}
             )
             evaluation_usage += usage
+            evaluation_history.append(current_metric)
             if early_stopping.update(current_metric):
                 logger.info(
                     f"Early stopping after {len(results)} samples with {self.metric_name} of {current_metric*100:.1f}%"
                 )
                 break

-        return self._aggregate_result(results), evaluation_usage
+        return self._aggregate_result(results), evaluation_usage, evaluation_history

     @log_calls("Evaluating validation dataset")
     @lru_cache(maxsize=None)
-    def evaluate_validation(self, prompt: str):
-        return self.evaluate(prompt, self.validation_dataset)
+    def evaluate_validation(
+        self, prompt: str, parent_histories: list[list[float]] | None = None
+    ):
+        return self.evaluate(prompt, self.validation_dataset, parent_histories)

     @log_calls("Evaluating test dataset")
     def evaluate_test(self, prompt: str):
...
@@ -207,7 +273,7 @@ class SentimentAnalysis(Task):
                     break
             else:
                 logger.warning(f"Invalid answer: {response}")
-                return "failed"
+                return "failed", usage

         classification_result = (
             "incorrect" if answer_label != datum["label"] else "correct"
...
@@ -333,7 +399,7 @@ class QuestionAnswering(Task):
     def _aggregate_result(self, results: list[float]) -> float:
         return sum(results) / len(results)

-    def evaluate(self, prompt: str, dataset: Dataset) -> tuple[float, ModelUsage]:
+    def evaluate(self, prompt: str, dataset: Dataset):
         def replace_symbol_for_grammar(sample: DatasetDatum):
             symbol_replacement_mapping = {
                 "\u2013": "-",
...
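Behaviour of the new monitor in context: Task.evaluate keeps MomentBasedStopping as the fallback when no parent histories are supplied, and otherwise stops once the candidate prompt has failed to beat its best parent's running score by at least min_improvement for patience consecutive samples (after start_after samples have been seen). A small self-contained sketch of that behaviour, assuming task.py is importable; the scores and parameters are made up for illustration:

from task import ParentBaselineBasedStopping  # assumes task.py is on the import path

# Two parents whose validation metric plateaued around 0.62-0.64.
monitor = ParentBaselineBasedStopping(
    parent_histories=[[0.61, 0.63, 0.64], [0.58, 0.60, 0.62]],
    patience=3,
    start_after=2,
    min_improvement=0.01,
)

# The child never improves on the best parent, so the monitor asks to stop
# once `patience` comparisons have been collected.
for step, score in enumerate([0.60, 0.61, 0.62, 0.62, 0.62], start=1):
    if monitor.update(score):
        print(f"early stopping after {step} samples")  # stops at step 4 here
        break

In Task.evaluate both monitors receive the same patience and start_after values derived from the dataset size, with a floor of 5.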