Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
884af444
Commit
884af444
authored
Feb 17, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
better train/test callback management
parent
d731cf7b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
98 additions
and
99 deletions
+98
-99
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+9
-3
tensorpack/callbacks/common.py
tensorpack/callbacks/common.py
+5
-1
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+73
-88
tensorpack/callbacks/validation_callback.py
tensorpack/callbacks/validation_callback.py
+3
-3
tensorpack/models/fc.py
tensorpack/models/fc.py
+8
-4
No files found.
tensorpack/callbacks/base.py
View file @
884af444
...
@@ -11,14 +11,20 @@ from abc import abstractmethod, ABCMeta
...
@@ -11,14 +11,20 @@ from abc import abstractmethod, ABCMeta
from
..utils
import
*
from
..utils
import
*
__all__
=
[
'Callback'
,
'PeriodicCallback'
]
__all__
=
[
'Callback'
,
'PeriodicCallback'
,
'TrainCallback'
,
'TestCallback'
]
class
TrainCallback
(
object
):
pass
class
TestCallback
(
object
):
pass
class
Callback
(
object
):
class
Callback
(
object
):
__metaclass__
=
ABCMeta
__metaclass__
=
ABCMeta
running_graph
=
'train'
type
=
TrainCallback
()
""" The graph that this callback should run on.
""" The graph that this callback should run on.
Either
'train' or 'test'
Either
TrainCallback or TestCallback
"""
"""
def
before_train
(
self
):
def
before_train
(
self
):
...
...
tensorpack/callbacks/common.py
View file @
884af444
...
@@ -31,4 +31,8 @@ class PeriodicSaver(PeriodicCallback):
...
@@ -31,4 +31,8 @@ class PeriodicSaver(PeriodicCallback):
global_step
=
self
.
global_step
)
global_step
=
self
.
global_step
)
class
MinSaver
(
Callback
):
class
MinSaver
(
Callback
):
pass
def
__init__
(
self
,
monitor_stat
):
self
.
monitor_stat
=
monitor_stat
def
_trigger_epoch
(
self
):
pass
tensorpack/callbacks/group.py
View file @
884af444
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
.base
import
Callback
from
.base
import
Callback
,
TrainCallback
,
TestCallback
from
.summary
import
*
from
.summary
import
*
from
..utils
import
*
from
..utils
import
*
...
@@ -41,6 +41,12 @@ class CallbackTimeLogger(object):
...
@@ -41,6 +41,12 @@ class CallbackTimeLogger(object):
self
.
tot
+=
time
self
.
tot
+=
time
self
.
times
.
append
((
name
,
time
))
self
.
times
.
append
((
name
,
time
))
@
contextmanager
def
timed_callback
(
self
,
name
):
s
=
time
.
time
()
yield
self
.
add
(
name
,
time
.
time
()
-
s
)
def
log
(
self
):
def
log
(
self
):
""" log the time of some heavy callbacks """
""" log the time of some heavy callbacks """
if
self
.
tot
<
3
:
if
self
.
tot
<
3
:
...
@@ -53,119 +59,98 @@ class CallbackTimeLogger(object):
...
@@ -53,119 +59,98 @@ class CallbackTimeLogger(object):
"Callbacks took {:.3f} sec in total. {}"
.
format
(
"Callbacks took {:.3f} sec in total. {}"
.
format
(
self
.
tot
,
' '
.
join
(
msgs
)))
self
.
tot
,
' '
.
join
(
msgs
)))
class
TestCallbackContext
(
object
):
class
TrainCallbacks
(
Callback
):
def
__init__
(
self
,
callbacks
):
self
.
cbs
=
callbacks
for
idx
,
cb
in
enumerate
(
self
.
cbs
):
# put SummaryWriter to the beginning
if
type
(
cb
)
==
SummaryWriter
:
self
.
cbs
.
insert
(
0
,
self
.
cbs
.
pop
(
idx
))
break
else
:
logger
.
warn
(
"SummaryWriter must be used! Insert a default one automatically."
)
self
.
cbs
.
insert
(
0
,
SummaryWriter
())
def
_before_train
(
self
):
for
cb
in
self
.
cbs
:
cb
.
before_train
()
def
_after_train
(
self
):
for
cb
in
self
.
cbs
:
cb
.
after_train
()
def
trigger_step
(
self
):
for
cb
in
self
.
cbs
:
cb
.
trigger_step
()
def
_trigger_epoch
(
self
):
tm
=
CallbackTimeLogger
()
for
cb
in
self
.
cbs
:
s
=
time
.
time
()
cb
.
trigger_epoch
()
tm
.
add
(
type
(
cb
)
.
__name__
,
time
.
time
()
-
s
)
tm
.
log
()
class
TestCallbacks
(
Callback
):
"""
"""
Hold callbacks to be run in testing graph.
A class holding the context needed for running TestCallback
Will set a context with testing graph and testing session, for
each test-time callback to run
"""
"""
def
__init__
(
self
,
callbacks
):
def
__init__
(
self
):
self
.
cbs
=
callbacks
self
.
sess
=
None
def
_
before_train
(
self
):
def
_
init_test_sess
(
self
):
with
create_test_session
()
as
sess
:
with
create_test_session
()
as
sess
:
self
.
sess
=
sess
self
.
sess
=
sess
self
.
graph
=
sess
.
graph
self
.
graph
=
sess
.
graph
self
.
saver
=
tf
.
train
.
Saver
()
self
.
saver
=
tf
.
train
.
Saver
()
for
cb
in
self
.
cbs
:
cb
.
before_train
()
def
_after_train
(
self
):
@
contextmanager
for
cb
in
self
.
cbs
:
def
before_train_context
(
self
):
cb
.
after_train
()
if
self
.
sess
is
None
:
self
.
_init_test_sess
()
with
self
.
graph
.
as_default
(),
self
.
sess
.
as_default
():
yield
def
_trigger_epoch
(
self
):
# TODO also do this for after_train?
if
not
self
.
cbs
:
return
def
restore_checkpoint
(
self
):
tm
=
CallbackTimeLogger
()
ckpt
=
tf
.
train
.
get_checkpoint_state
(
logger
.
LOG_DIR
)
if
ckpt
is
None
:
raise
RuntimeError
(
"Cannot find a checkpoint state. Do you forget to use PeriodicSaver before any TestCallback?"
)
logger
.
info
(
"Restore checkpoint from {}"
.
format
(
ckpt
.
model_checkpoint_path
))
self
.
saver
.
restore
(
self
.
sess
,
ckpt
.
model_checkpoint_path
)
@
contextmanager
def
trigger_epoch_context
(
self
):
with
self
.
graph
.
as_default
(),
self
.
sess
.
as_default
():
with
self
.
graph
.
as_default
(),
self
.
sess
.
as_default
():
s
=
time
.
time
()
yield
ckpt
=
tf
.
train
.
get_checkpoint_state
(
logger
.
LOG_DIR
)
if
ckpt
is
None
:
logger
.
error
(
"Cannot find a checkpoint state. Do you forget to use PeriodicSaver?"
)
return
logger
.
info
(
"Restore checkpoint from {}"
.
format
(
ckpt
.
model_checkpoint_path
))
self
.
saver
.
restore
(
self
.
sess
,
ckpt
.
model_checkpoint_path
)
tm
.
add
(
'restore session'
,
time
.
time
()
-
s
)
for
cb
in
self
.
cbs
:
s
=
time
.
time
()
cb
.
trigger_epoch
()
tm
.
add
(
type
(
cb
)
.
__name__
,
time
.
time
()
-
s
)
tm
.
log
()
class
Callbacks
(
Callback
):
class
Callbacks
(
Callback
):
def
__init__
(
self
,
cbs
):
def
__init__
(
self
,
cbs
):
train_cbs
=
[]
# check type
test_cbs
=
[]
for
cb
in
cbs
:
for
cb
in
cbs
:
assert
isinstance
(
cb
,
Callback
),
cb
.
__class__
assert
isinstance
(
cb
,
Callback
),
cb
.
__class__
if
cb
.
running_graph
==
'test'
:
if
not
isinstance
(
cb
.
type
,
(
TrainCallback
,
TestCallback
)):
test_cbs
.
append
(
cb
)
elif
cb
.
running_graph
==
'train'
:
train_cbs
.
append
(
cb
)
else
:
raise
ValueError
(
raise
ValueError
(
"Unknown callback running graph {}!"
.
format
(
cb
.
running_graph
))
"Unknown callback running graph {}!"
.
format
(
str
(
cb
.
type
)))
self
.
train
=
TrainCallbacks
(
train_cbs
)
if
test_cbs
:
# ensure a SummaryWriter
self
.
test
=
TestCallbacks
(
test_cbs
)
for
idx
,
cb
in
enumerate
(
cbs
):
if
type
(
cb
)
==
SummaryWriter
:
cbs
.
insert
(
0
,
cbs
.
pop
(
idx
))
break
else
:
else
:
self
.
test
=
None
logger
.
warn
(
"SummaryWriter must be used! Insert a default one automatically."
)
cbs
.
insert
(
0
,
SummaryWriter
())
self
.
cbs
=
cbs
self
.
test_callback_context
=
TestCallbackContext
()
def
_before_train
(
self
):
def
_before_train
(
self
):
self
.
train
.
before_train
()
for
cb
in
self
.
cbs
:
if
self
.
test
:
if
isinstance
(
cb
.
type
,
TrainCallback
):
self
.
test
.
before_train
()
cb
.
before_train
()
else
:
with
self
.
test_callback_context
.
before_train_context
():
cb
.
before_train
()
def
_after_train
(
self
):
def
_after_train
(
self
):
self
.
train
.
after_train
()
for
cb
in
self
.
cbs
:
if
self
.
test
:
cb
.
after_train
()
self
.
test
.
after_train
()
logger
.
writer
.
close
()
logger
.
writer
.
close
()
def
trigger_step
(
self
):
def
trigger_step
(
self
):
self
.
train
.
trigger_step
()
for
cb
in
self
.
cbs
:
if
isinstance
(
cb
.
type
,
TrainCallback
):
cb
.
trigger_step
()
# test callback don't have trigger_step
# test callback don't have trigger_step
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
self
.
train
.
trigger_epoch
()
tm
=
CallbackTimeLogger
()
if
self
.
test
:
self
.
test
.
trigger_epoch
()
test_sess_restored
=
False
for
cb
in
self
.
cbs
:
if
isinstance
(
cb
.
type
,
TrainCallback
):
with
tm
.
timed_callback
(
type
(
cb
)
.
__name__
):
cb
.
trigger_epoch
()
else
:
if
not
test_sess_restored
:
with
tm
.
timed_callback
(
'restore checkpoint'
):
self
.
test_callback_context
.
restore_checkpoint
()
test_sess_restored
=
True
with
self
.
test_callback_context
.
trigger_epoch_context
(),
\
tm
.
timed_callback
(
type
(
cb
)
.
__name__
):
cb
.
trigger_epoch
()
tm
.
log
()
logger
.
writer
.
flush
()
logger
.
writer
.
flush
()
logger
.
stat_holder
.
finalize
()
logger
.
stat_holder
.
finalize
()
tensorpack/callbacks/validation_callback.py
View file @
884af444
...
@@ -10,12 +10,12 @@ from tqdm import tqdm
...
@@ -10,12 +10,12 @@ from tqdm import tqdm
from
..utils
import
*
from
..utils
import
*
from
..utils.stat
import
*
from
..utils.stat
import
*
from
..utils.summary
import
*
from
..utils.summary
import
*
from
.base
import
PeriodicCallback
,
Callback
from
.base
import
PeriodicCallback
,
Callback
,
TestCallback
__all__
=
[
'ValidationError'
,
'ValidationCallback'
]
__all__
=
[
'ValidationError'
,
'ValidationCallback'
]
class
ValidationCallback
(
PeriodicCallback
):
class
ValidationCallback
(
PeriodicCallback
):
running_graph
=
'test'
type
=
TestCallback
()
"""
"""
Basic routine for validation callbacks.
Basic routine for validation callbacks.
"""
"""
...
@@ -48,7 +48,7 @@ class ValidationCallback(PeriodicCallback):
...
@@ -48,7 +48,7 @@ class ValidationCallback(PeriodicCallback):
output_vars
=
self
.
_get_output_vars
()
output_vars
=
self
.
_get_output_vars
()
output_vars
.
append
(
self
.
cost_var
)
output_vars
.
append
(
self
.
cost_var
)
with
tqdm
(
total
=
self
.
ds
.
size
())
as
pbar
:
with
tqdm
(
total
=
self
.
ds
.
size
()
,
ascii
=
True
)
as
pbar
:
for
dp
in
self
.
ds
.
get_data
():
for
dp
in
self
.
ds
.
get_data
():
feed
=
dict
(
itertools
.
izip
(
self
.
input_vars
,
dp
))
feed
=
dict
(
itertools
.
izip
(
self
.
input_vars
,
dp
))
...
...
tensorpack/models/fc.py
View file @
884af444
...
@@ -12,7 +12,9 @@ from ..utils.symbolic_functions import *
...
@@ -12,7 +12,9 @@ from ..utils.symbolic_functions import *
__all__
=
[
'FullyConnected'
]
__all__
=
[
'FullyConnected'
]
@
layer_register
(
summary_activation
=
True
)
@
layer_register
(
summary_activation
=
True
)
def
FullyConnected
(
x
,
out_dim
,
W_init
=
None
,
b_init
=
None
,
nl
=
tf
.
nn
.
relu
):
def
FullyConnected
(
x
,
out_dim
,
W_init
=
None
,
b_init
=
None
,
nl
=
tf
.
nn
.
relu
,
use_bias
=
True
):
x
=
batch_flatten
(
x
)
x
=
batch_flatten
(
x
)
in_dim
=
x
.
get_shape
()
.
as_list
()[
1
]
in_dim
=
x
.
get_shape
()
.
as_list
()[
1
]
...
@@ -20,9 +22,11 @@ def FullyConnected(x, out_dim, W_init=None, b_init=None, nl=tf.nn.relu):
...
@@ -20,9 +22,11 @@ def FullyConnected(x, out_dim, W_init=None, b_init=None, nl=tf.nn.relu):
#W_init = tf.truncated_normal_initializer(stddev=1 / math.sqrt(float(in_dim)))
#W_init = tf.truncated_normal_initializer(stddev=1 / math.sqrt(float(in_dim)))
W_init
=
tf
.
uniform_unit_scaling_initializer
()
W_init
=
tf
.
uniform_unit_scaling_initializer
()
if
b_init
is
None
:
if
b_init
is
None
:
b_init
=
tf
.
constant_initializer
(
0.0
)
b_init
=
tf
.
constant_initializer
()
with
tf
.
device
(
'/cpu:0'
):
with
tf
.
device
(
'/cpu:0'
):
W
=
tf
.
get_variable
(
'W'
,
[
in_dim
,
out_dim
],
initializer
=
W_init
)
W
=
tf
.
get_variable
(
'W'
,
[
in_dim
,
out_dim
],
initializer
=
W_init
)
b
=
tf
.
get_variable
(
'b'
,
[
out_dim
],
initializer
=
b_init
)
if
use_bias
:
return
nl
(
tf
.
nn
.
xw_plus_b
(
x
,
W
,
b
),
name
=
tf
.
get_variable_scope
()
.
name
+
'_output'
)
b
=
tf
.
get_variable
(
'b'
,
[
out_dim
],
initializer
=
b_init
)
prod
=
tf
.
nn
.
xw_plus_b
(
x
,
W
,
b
)
if
use_bias
else
tf
.
matmul
(
x
,
W
)
return
nl
(
prod
,
name
=
tf
.
get_variable_scope
()
.
name
+
'_output'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment