Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
884af444
Commit
884af444
authored
Feb 17, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
better train/test callback management
parent
d731cf7b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
98 additions
and
99 deletions
+98
-99
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+9
-3
tensorpack/callbacks/common.py
tensorpack/callbacks/common.py
+5
-1
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+73
-88
tensorpack/callbacks/validation_callback.py
tensorpack/callbacks/validation_callback.py
+3
-3
tensorpack/models/fc.py
tensorpack/models/fc.py
+8
-4
No files found.
tensorpack/callbacks/base.py
View file @
884af444
...
...
@@ -11,14 +11,20 @@ from abc import abstractmethod, ABCMeta
from
..utils
import
*
__all__
=
[
'Callback'
,
'PeriodicCallback'
]
__all__
=
[
'Callback'
,
'PeriodicCallback'
,
'TrainCallback'
,
'TestCallback'
]
class
TrainCallback
(
object
):
pass
class
TestCallback
(
object
):
pass
class
Callback
(
object
):
__metaclass__
=
ABCMeta
running_graph
=
'train'
type
=
TrainCallback
()
""" The graph that this callback should run on.
Either
'train' or 'test'
Either
TrainCallback or TestCallback
"""
def
before_train
(
self
):
...
...
tensorpack/callbacks/common.py
View file @
884af444
...
...
@@ -31,4 +31,8 @@ class PeriodicSaver(PeriodicCallback):
global_step
=
self
.
global_step
)
class
MinSaver
(
Callback
):
pass
def
__init__
(
self
,
monitor_stat
):
self
.
monitor_stat
=
monitor_stat
def
_trigger_epoch
(
self
):
pass
tensorpack/callbacks/group.py
View file @
884af444
...
...
@@ -6,7 +6,7 @@
import
tensorflow
as
tf
from
contextlib
import
contextmanager
from
.base
import
Callback
from
.base
import
Callback
,
TrainCallback
,
TestCallback
from
.summary
import
*
from
..utils
import
*
...
...
@@ -41,6 +41,12 @@ class CallbackTimeLogger(object):
self
.
tot
+=
time
self
.
times
.
append
((
name
,
time
))
@
contextmanager
def
timed_callback
(
self
,
name
):
s
=
time
.
time
()
yield
self
.
add
(
name
,
time
.
time
()
-
s
)
def
log
(
self
):
""" log the time of some heavy callbacks """
if
self
.
tot
<
3
:
...
...
@@ -53,119 +59,98 @@ class CallbackTimeLogger(object):
"Callbacks took {:.3f} sec in total. {}"
.
format
(
self
.
tot
,
' '
.
join
(
msgs
)))
class
TrainCallbacks
(
Callback
):
def
__init__
(
self
,
callbacks
):
self
.
cbs
=
callbacks
for
idx
,
cb
in
enumerate
(
self
.
cbs
):
# put SummaryWriter to the beginning
if
type
(
cb
)
==
SummaryWriter
:
self
.
cbs
.
insert
(
0
,
self
.
cbs
.
pop
(
idx
))
break
else
:
logger
.
warn
(
"SummaryWriter must be used! Insert a default one automatically."
)
self
.
cbs
.
insert
(
0
,
SummaryWriter
())
def
_before_train
(
self
):
for
cb
in
self
.
cbs
:
cb
.
before_train
()
def
_after_train
(
self
):
for
cb
in
self
.
cbs
:
cb
.
after_train
()
def
trigger_step
(
self
):
for
cb
in
self
.
cbs
:
cb
.
trigger_step
()
def
_trigger_epoch
(
self
):
tm
=
CallbackTimeLogger
()
for
cb
in
self
.
cbs
:
s
=
time
.
time
()
cb
.
trigger_epoch
()
tm
.
add
(
type
(
cb
)
.
__name__
,
time
.
time
()
-
s
)
tm
.
log
()
class
TestCallbacks
(
Callback
):
class
TestCallbackContext
(
object
):
"""
Hold callbacks to be run in testing graph.
Will set a context with testing graph and testing session, for
each test-time callback to run
A class holding the context needed for running TestCallback
"""
def
__init__
(
self
,
callbacks
):
self
.
cbs
=
callbacks
def
__init__
(
self
):
self
.
sess
=
None
def
_
before_train
(
self
):
def
_
init_test_sess
(
self
):
with
create_test_session
()
as
sess
:
self
.
sess
=
sess
self
.
graph
=
sess
.
graph
self
.
saver
=
tf
.
train
.
Saver
()
for
cb
in
self
.
cbs
:
cb
.
before_train
()
def
_after_train
(
self
):
for
cb
in
self
.
cbs
:
cb
.
after_train
()
@
contextmanager
def
before_train_context
(
self
):
if
self
.
sess
is
None
:
self
.
_init_test_sess
()
with
self
.
graph
.
as_default
(),
self
.
sess
.
as_default
():
yield
def
_trigger_epoch
(
self
):
if
not
self
.
cbs
:
return
tm
=
CallbackTimeLogger
()
# TODO also do this for after_train?
def
restore_checkpoint
(
self
):
ckpt
=
tf
.
train
.
get_checkpoint_state
(
logger
.
LOG_DIR
)
if
ckpt
is
None
:
raise
RuntimeError
(
"Cannot find a checkpoint state. Do you forget to use PeriodicSaver before any TestCallback?"
)
logger
.
info
(
"Restore checkpoint from {}"
.
format
(
ckpt
.
model_checkpoint_path
))
self
.
saver
.
restore
(
self
.
sess
,
ckpt
.
model_checkpoint_path
)
@
contextmanager
def
trigger_epoch_context
(
self
):
with
self
.
graph
.
as_default
(),
self
.
sess
.
as_default
():
s
=
time
.
time
()
ckpt
=
tf
.
train
.
get_checkpoint_state
(
logger
.
LOG_DIR
)
if
ckpt
is
None
:
logger
.
error
(
"Cannot find a checkpoint state. Do you forget to use PeriodicSaver?"
)
return
logger
.
info
(
"Restore checkpoint from {}"
.
format
(
ckpt
.
model_checkpoint_path
))
self
.
saver
.
restore
(
self
.
sess
,
ckpt
.
model_checkpoint_path
)
tm
.
add
(
'restore session'
,
time
.
time
()
-
s
)
for
cb
in
self
.
cbs
:
s
=
time
.
time
()
cb
.
trigger_epoch
()
tm
.
add
(
type
(
cb
)
.
__name__
,
time
.
time
()
-
s
)
tm
.
log
()
yield
class
Callbacks
(
Callback
):
def
__init__
(
self
,
cbs
):
train_cbs
=
[]
test_cbs
=
[]
# check type
for
cb
in
cbs
:
assert
isinstance
(
cb
,
Callback
),
cb
.
__class__
if
cb
.
running_graph
==
'test'
:
test_cbs
.
append
(
cb
)
elif
cb
.
running_graph
==
'train'
:
train_cbs
.
append
(
cb
)
else
:
if
not
isinstance
(
cb
.
type
,
(
TrainCallback
,
TestCallback
)):
raise
ValueError
(
"Unknown callback running graph {}!"
.
format
(
cb
.
running_graph
))
self
.
train
=
TrainCallbacks
(
train_cbs
)
if
test_cbs
:
self
.
test
=
TestCallbacks
(
test_cbs
)
"Unknown callback running graph {}!"
.
format
(
str
(
cb
.
type
)))
# ensure a SummaryWriter
for
idx
,
cb
in
enumerate
(
cbs
):
if
type
(
cb
)
==
SummaryWriter
:
cbs
.
insert
(
0
,
cbs
.
pop
(
idx
))
break
else
:
self
.
test
=
None
logger
.
warn
(
"SummaryWriter must be used! Insert a default one automatically."
)
cbs
.
insert
(
0
,
SummaryWriter
())
self
.
cbs
=
cbs
self
.
test_callback_context
=
TestCallbackContext
()
def
_before_train
(
self
):
self
.
train
.
before_train
()
if
self
.
test
:
self
.
test
.
before_train
()
for
cb
in
self
.
cbs
:
if
isinstance
(
cb
.
type
,
TrainCallback
):
cb
.
before_train
()
else
:
with
self
.
test_callback_context
.
before_train_context
():
cb
.
before_train
()
def
_after_train
(
self
):
self
.
train
.
after_train
()
if
self
.
test
:
self
.
test
.
after_train
()
for
cb
in
self
.
cbs
:
cb
.
after_train
()
logger
.
writer
.
close
()
def
trigger_step
(
self
):
self
.
train
.
trigger_step
()
for
cb
in
self
.
cbs
:
if
isinstance
(
cb
.
type
,
TrainCallback
):
cb
.
trigger_step
()
# test callback don't have trigger_step
def
_trigger_epoch
(
self
):
self
.
train
.
trigger_epoch
()
if
self
.
test
:
self
.
test
.
trigger_epoch
()
tm
=
CallbackTimeLogger
()
test_sess_restored
=
False
for
cb
in
self
.
cbs
:
if
isinstance
(
cb
.
type
,
TrainCallback
):
with
tm
.
timed_callback
(
type
(
cb
)
.
__name__
):
cb
.
trigger_epoch
()
else
:
if
not
test_sess_restored
:
with
tm
.
timed_callback
(
'restore checkpoint'
):
self
.
test_callback_context
.
restore_checkpoint
()
test_sess_restored
=
True
with
self
.
test_callback_context
.
trigger_epoch_context
(),
\
tm
.
timed_callback
(
type
(
cb
)
.
__name__
):
cb
.
trigger_epoch
()
tm
.
log
()
logger
.
writer
.
flush
()
logger
.
stat_holder
.
finalize
()
tensorpack/callbacks/validation_callback.py
View file @
884af444
...
...
@@ -10,12 +10,12 @@ from tqdm import tqdm
from
..utils
import
*
from
..utils.stat
import
*
from
..utils.summary
import
*
from
.base
import
PeriodicCallback
,
Callback
from
.base
import
PeriodicCallback
,
Callback
,
TestCallback
__all__
=
[
'ValidationError'
,
'ValidationCallback'
]
class
ValidationCallback
(
PeriodicCallback
):
running_graph
=
'test'
type
=
TestCallback
()
"""
Basic routine for validation callbacks.
"""
...
...
@@ -48,7 +48,7 @@ class ValidationCallback(PeriodicCallback):
output_vars
=
self
.
_get_output_vars
()
output_vars
.
append
(
self
.
cost_var
)
with
tqdm
(
total
=
self
.
ds
.
size
())
as
pbar
:
with
tqdm
(
total
=
self
.
ds
.
size
()
,
ascii
=
True
)
as
pbar
:
for
dp
in
self
.
ds
.
get_data
():
feed
=
dict
(
itertools
.
izip
(
self
.
input_vars
,
dp
))
...
...
tensorpack/models/fc.py
View file @
884af444
...
...
@@ -12,7 +12,9 @@ from ..utils.symbolic_functions import *
__all__
=
[
'FullyConnected'
]
@
layer_register
(
summary_activation
=
True
)
def
FullyConnected
(
x
,
out_dim
,
W_init
=
None
,
b_init
=
None
,
nl
=
tf
.
nn
.
relu
):
def
FullyConnected
(
x
,
out_dim
,
W_init
=
None
,
b_init
=
None
,
nl
=
tf
.
nn
.
relu
,
use_bias
=
True
):
x
=
batch_flatten
(
x
)
in_dim
=
x
.
get_shape
()
.
as_list
()[
1
]
...
...
@@ -20,9 +22,11 @@ def FullyConnected(x, out_dim, W_init=None, b_init=None, nl=tf.nn.relu):
#W_init = tf.truncated_normal_initializer(stddev=1 / math.sqrt(float(in_dim)))
W_init
=
tf
.
uniform_unit_scaling_initializer
()
if
b_init
is
None
:
b_init
=
tf
.
constant_initializer
(
0.0
)
b_init
=
tf
.
constant_initializer
()
with
tf
.
device
(
'/cpu:0'
):
W
=
tf
.
get_variable
(
'W'
,
[
in_dim
,
out_dim
],
initializer
=
W_init
)
b
=
tf
.
get_variable
(
'b'
,
[
out_dim
],
initializer
=
b_init
)
return
nl
(
tf
.
nn
.
xw_plus_b
(
x
,
W
,
b
),
name
=
tf
.
get_variable_scope
()
.
name
+
'_output'
)
if
use_bias
:
b
=
tf
.
get_variable
(
'b'
,
[
out_dim
],
initializer
=
b_init
)
prod
=
tf
.
nn
.
xw_plus_b
(
x
,
W
,
b
)
if
use_bias
else
tf
.
matmul
(
x
,
W
)
return
nl
(
prod
,
name
=
tf
.
get_variable_scope
()
.
name
+
'_output'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment