Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ba2e7ff0
Commit
ba2e7ff0
authored
Feb 08, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
rewrite method in callbacks
parent
4fa2837e
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
32 additions
and
19 deletions
+32
-19
example_mnist.py
example_mnist.py
+1
-0
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+12
-8
tensorpack/callbacks/common.py
tensorpack/callbacks/common.py
+10
-2
tensorpack/callbacks/dump.py
tensorpack/callbacks/dump.py
+1
-1
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+6
-6
tensorpack/callbacks/validation_callback.py
tensorpack/callbacks/validation_callback.py
+2
-2
No files found.
example_mnist.py
View file @
ba2e7ff0
...
...
@@ -85,6 +85,7 @@ def get_config():
dataset_train
=
BatchData
(
dataset
.
Mnist
(
'train'
),
128
)
dataset_test
=
BatchData
(
dataset
.
Mnist
(
'test'
),
256
,
remainder
=
True
)
step_per_epoch
=
dataset_train
.
size
()
#step_per_epoch = 3
#dataset_test = FixedSizeData(dataset_test, 20)
sess_config
=
get_default_sess_config
()
...
...
tensorpack/callbacks/base.py
View file @
ba2e7ff0
...
...
@@ -24,6 +24,7 @@ class Callback(object):
def
before_train
(
self
):
self
.
graph
=
tf
.
get_default_graph
()
self
.
sess
=
tf
.
get_default_session
()
self
.
epoch_num
=
0
self
.
_before_train
()
def
_before_train
(
self
):
...
...
@@ -46,22 +47,25 @@ class Callback(object):
"""
def
trigger_epoch
(
self
):
self
.
epoch_num
+=
1
self
.
global_step
=
get_global_step
()
self
.
_trigger_epoch
()
def
_trigger_epoch
(
self
):
"""
Callback to be triggered after every epoch (full iteration of input dataset)
"""
class
PeriodicCallback
(
Callback
):
def
__init__
(
self
,
period
):
self
.
__period
=
period
self
.
epoch_num
=
0
self
.
period
=
period
def
trigger_epoch
(
self
):
self
.
epoch_num
+=
1
if
self
.
epoch_num
%
self
.
__period
==
0
:
self
.
global_step
=
get_global_step
()
self
.
_trigger
()
def
_trigger_epoch
(
self
):
if
self
.
epoch_num
%
self
.
period
==
0
:
self
.
_trigger_periodic
()
@
abstractmethod
def
_trigger
(
self
):
def
_trigger
_periodic
(
self
):
pass
tensorpack/callbacks/common.py
View file @
ba2e7ff0
...
...
@@ -24,7 +24,7 @@ class PeriodicSaver(PeriodicCallback):
max_to_keep
=
self
.
keep_recent
,
keep_checkpoint_every_n_hours
=
self
.
keep_freq
)
def
_trigger
(
self
):
def
_trigger
_periodic
(
self
):
self
.
saver
.
save
(
tf
.
get_default_session
(),
self
.
path
,
...
...
@@ -40,13 +40,16 @@ class SummaryWriter(Callback):
self
.
log_dir
,
graph_def
=
self
.
sess
.
graph_def
)
tf
.
add_to_collection
(
SUMMARY_WRITER_COLLECTION_KEY
,
self
.
writer
)
self
.
summary_op
=
tf
.
merge_all_summaries
()
self
.
epoch_num
=
0
def
trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
self
.
epoch_num
+=
1
# check if there is any summary
if
self
.
summary_op
is
None
:
return
summary_str
=
self
.
summary_op
.
eval
()
summary
=
tf
.
Summary
.
FromString
(
summary_str
)
printed_tag
=
set
()
for
val
in
summary
.
value
:
#print val.tag
val
.
tag
=
re
.
sub
(
'tower[0-9]*/'
,
''
,
val
.
tag
)
...
...
@@ -54,7 +57,12 @@ class SummaryWriter(Callback):
assert
val
.
WhichOneof
(
'value'
)
==
'simple_value'
,
\
'Cannot print summary {}: not a simple_value summary!'
.
format
(
val
.
tag
)
logger
.
info
(
'{}: {:.4f}'
.
format
(
val
.
tag
,
val
.
simple_value
))
printed_tag
.
add
(
val
.
tag
)
self
.
writer
.
add_summary
(
summary
,
get_global_step
())
if
self
.
epoch_num
==
1
:
if
len
(
printed_tag
)
!=
len
(
self
.
print_tag
):
logger
.
warn
(
"Tags to print not found in Summary Writer: {}"
.
format
(
", "
.
join
([
k
for
k
in
self
.
print_tag
if
k
not
in
printed_tag
])))
def
_after_train
(
self
):
self
.
writer
.
close
()
...
...
tensorpack/callbacks/dump.py
View file @
ba2e7ff0
...
...
@@ -30,7 +30,7 @@ class DumpParamAsImage(Callback):
self
.
var
=
self
.
graph
.
get_tensor_by_name
(
self
.
var_name
)
self
.
epoch_num
=
0
def
trigger_epoch
(
self
):
def
_
trigger_epoch
(
self
):
self
.
epoch_num
+=
1
val
=
self
.
sess
.
run
(
self
.
var
)
if
self
.
func
is
not
None
:
...
...
tensorpack/callbacks/group.py
View file @
ba2e7ff0
...
...
@@ -73,7 +73,7 @@ class TrainCallbacks(Callback):
else
:
raise
ValueError
(
"Callbacks must contain a SummaryWriter!"
)
def
before_train
(
self
):
def
_
before_train
(
self
):
for
cb
in
self
.
cbs
:
cb
.
before_train
()
self
.
writer
=
tf
.
get_collection
(
SUMMARY_WRITER_COLLECTION_KEY
)[
0
]
...
...
@@ -86,7 +86,7 @@ class TrainCallbacks(Callback):
for
cb
in
self
.
cbs
:
cb
.
trigger_step
()
def
trigger_epoch
(
self
):
def
_
trigger_epoch
(
self
):
tm
=
CallbackTimeLogger
()
for
cb
in
self
.
cbs
:
s
=
time
.
time
()
...
...
@@ -104,7 +104,7 @@ class TestCallbacks(Callback):
def
__init__
(
self
,
callbacks
):
self
.
cbs
=
callbacks
def
before_train
(
self
):
def
_
before_train
(
self
):
self
.
writer
=
tf
.
get_collection
(
SUMMARY_WRITER_COLLECTION_KEY
)[
0
]
with
create_test_session
()
as
sess
:
self
.
sess
=
sess
...
...
@@ -119,7 +119,7 @@ class TestCallbacks(Callback):
for
cb
in
self
.
cbs
:
cb
.
after_train
()
def
trigger_epoch
(
self
):
def
_
trigger_epoch
(
self
):
if
not
self
.
cbs
:
return
tm
=
CallbackTimeLogger
()
...
...
@@ -157,7 +157,7 @@ class Callbacks(Callback):
self
.
train
=
TrainCallbacks
(
train_cbs
)
self
.
test
=
TestCallbacks
(
test_cbs
)
def
before_train
(
self
):
def
_
before_train
(
self
):
self
.
train
.
before_train
()
self
.
test
.
before_train
()
...
...
@@ -169,7 +169,7 @@ class Callbacks(Callback):
self
.
train
.
trigger_step
()
# test callback don't have trigger_step
def
trigger_epoch
(
self
):
def
_
trigger_epoch
(
self
):
self
.
train
.
trigger_epoch
()
# TODO test callbacks can be run async?
self
.
test
.
trigger_epoch
()
tensorpack/callbacks/validation_callback.py
View file @
ba2e7ff0
...
...
@@ -68,7 +68,7 @@ class ValidationCallback(PeriodicCallback):
'{}_cost'
.
format
(
self
.
prefix
),
cost_avg
),
self
.
global_step
)
logger
.
info
(
"{}_cost: {:.4f}"
.
format
(
self
.
prefix
,
cost_avg
))
def
_trigger
(
self
):
def
_trigger
_periodic
(
self
):
for
dp
,
outputs
in
self
.
_run_validation
():
pass
...
...
@@ -95,7 +95,7 @@ class ValidationError(ValidationCallback):
def
_get_output_vars
(
self
):
return
[
self
.
wrong_var
]
def
_trigger
(
self
):
def
_trigger
_periodic
(
self
):
err_stat
=
Accuracy
()
for
dp
,
outputs
in
self
.
_run_validation
():
batch_size
=
dp
[
0
]
.
shape
[
0
]
# assume batched input
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment