Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
4aaf06ca
Commit
4aaf06ca
authored
May 14, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
separate setup_graph & before_train
parent
e69034b5
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
39 additions
and
17 deletions
+39
-17
examples/ResNet/svhn-resnet.py
examples/ResNet/svhn-resnet.py
+1
-2
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+14
-4
tensorpack/callbacks/common.py
tensorpack/callbacks/common.py
+1
-1
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+14
-6
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+7
-1
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+0
-2
tensorpack/train/base.py
tensorpack/train/base.py
+2
-1
No files found.
examples/ResNet/svhn-resnet.py
View file @
4aaf06ca
...
@@ -75,7 +75,6 @@ class Model(ModelDesc):
...
@@ -75,7 +75,6 @@ class Model(ModelDesc):
l
=
c2
+
l
l
=
c2
+
l
return
l
return
l
l
=
conv
(
'conv0'
,
image
,
16
,
1
)
l
=
conv
(
'conv0'
,
image
,
16
,
1
)
l
=
BatchNorm
(
'bn0'
,
l
,
is_training
)
l
=
BatchNorm
(
'bn0'
,
l
,
is_training
)
l
=
tf
.
nn
.
relu
(
l
)
l
=
tf
.
nn
.
relu
(
l
)
...
@@ -113,7 +112,7 @@ class Model(ModelDesc):
...
@@ -113,7 +112,7 @@ class Model(ModelDesc):
#wd_cost = regularize_cost('.*/W', l2_regularizer(0.0002), name='regularize_loss')
#wd_cost = regularize_cost('.*/W', l2_regularizer(0.0002), name='regularize_loss')
wd_w
=
tf
.
train
.
exponential_decay
(
0.0001
,
get_global_step_var
(),
wd_w
=
tf
.
train
.
exponential_decay
(
0.0001
,
get_global_step_var
(),
960000
,
0.5
,
True
)
960000
,
0.5
,
True
)
wd_cost
=
wd_w
*
regularize_cost
(
'.*/W'
,
tf
.
nn
.
l2_loss
)
wd_cost
=
tf
.
mul
(
wd_w
,
regularize_cost
(
'.*/W'
,
tf
.
nn
.
l2_loss
),
name
=
'wd_cost'
)
tf
.
add_to_collection
(
MOVING_SUMMARY_VARS_KEY
,
wd_cost
)
tf
.
add_to_collection
(
MOVING_SUMMARY_VARS_KEY
,
wd_cost
)
add_param_summary
([(
'.*/W'
,
[
'histogram'
])])
# monitor W
add_param_summary
([(
'.*/W'
,
[
'histogram'
])])
# monitor W
...
...
tensorpack/callbacks/base.py
View file @
4aaf06ca
...
@@ -28,18 +28,28 @@ class Callback(object):
...
@@ -28,18 +28,28 @@ class Callback(object):
Default is `TrainCallbackType()`
Default is `TrainCallbackType()`
"""
"""
def
before_train
(
self
,
trainer
):
def
before_train
(
self
):
"""
"""
Called before starting iterative training.
Called right before the first iteration.
"""
self
.
_before_train
()
def
_before_train
(
self
):
pass
def
setup_graph
(
self
,
trainer
):
"""
Called before finalizing the graph.
Use this callback to setup some ops used in the callback.
:param trainer: a :class:`train.Trainer` instance
:param trainer: a :class:`train.Trainer` instance
"""
"""
self
.
trainer
=
trainer
self
.
trainer
=
trainer
self
.
graph
=
tf
.
get_default_graph
()
self
.
graph
=
tf
.
get_default_graph
()
self
.
epoch_num
=
self
.
trainer
.
config
.
starting_epoch
self
.
epoch_num
=
self
.
trainer
.
config
.
starting_epoch
self
.
_
before_train
()
self
.
_
setup_graph
()
def
_
before_train
(
self
):
def
_
setup_graph
(
self
):
pass
pass
def
after_train
(
self
):
def
after_train
(
self
):
...
...
tensorpack/callbacks/common.py
View file @
4aaf06ca
...
@@ -23,7 +23,7 @@ class ModelSaver(Callback):
...
@@ -23,7 +23,7 @@ class ModelSaver(Callback):
self
.
keep_recent
=
keep_recent
self
.
keep_recent
=
keep_recent
self
.
keep_freq
=
keep_freq
self
.
keep_freq
=
keep_freq
def
_
before_train
(
self
):
def
_
setup_graph
(
self
):
self
.
path
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
'model'
)
self
.
path
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
'model'
)
self
.
saver
=
tf
.
train
.
Saver
(
self
.
saver
=
tf
.
train
.
Saver
(
var_list
=
ModelSaver
.
_get_vars
(),
var_list
=
ModelSaver
.
_get_vars
(),
...
...
tensorpack/callbacks/group.py
View file @
4aaf06ca
...
@@ -66,7 +66,7 @@ class TestCallbackContext(object):
...
@@ -66,7 +66,7 @@ class TestCallbackContext(object):
self
.
sess
=
None
self
.
sess
=
None
@
contextmanager
@
contextmanager
def
before_train
_context
(
self
,
trainer
):
def
create
_context
(
self
,
trainer
):
if
self
.
sess
is
None
:
if
self
.
sess
is
None
:
with
create_test_session
(
trainer
)
as
sess
:
with
create_test_session
(
trainer
)
as
sess
:
self
.
sess
=
sess
self
.
sess
=
sess
...
@@ -88,7 +88,7 @@ class TestCallbackContext(object):
...
@@ -88,7 +88,7 @@ class TestCallbackContext(object):
self
.
saver
.
restore
(
self
.
sess
,
ckpt
.
model_checkpoint_path
)
self
.
saver
.
restore
(
self
.
sess
,
ckpt
.
model_checkpoint_path
)
@
contextmanager
@
contextmanager
def
t
rigger_epoch
_context
(
self
):
def
t
est
_context
(
self
):
with
self
.
graph
.
as_default
(),
self
.
sess
.
as_default
():
with
self
.
graph
.
as_default
(),
self
.
sess
.
as_default
():
yield
yield
...
@@ -110,13 +110,21 @@ class Callbacks(Callback):
...
@@ -110,13 +110,21 @@ class Callbacks(Callback):
self
.
cbs
=
cbs
self
.
cbs
=
cbs
self
.
test_callback_context
=
TestCallbackContext
()
self
.
test_callback_context
=
TestCallbackContext
()
def
_setup_graph
(
self
):
for
cb
in
self
.
cbs
:
if
isinstance
(
cb
.
type
,
TrainCallbackType
):
cb
.
setup_graph
(
self
.
trainer
)
else
:
with
self
.
test_callback_context
.
create_context
(
self
.
trainer
):
cb
.
setup_graph
(
self
.
trainer
)
def
_before_train
(
self
):
def
_before_train
(
self
):
for
cb
in
self
.
cbs
:
for
cb
in
self
.
cbs
:
if
isinstance
(
cb
.
type
,
TrainCallbackType
):
if
isinstance
(
cb
.
type
,
TrainCallbackType
):
cb
.
before_train
(
self
.
trainer
)
cb
.
before_train
()
else
:
else
:
with
self
.
test_callback_context
.
before_train_context
(
self
.
trainer
):
with
self
.
test_callback_context
.
test_context
(
):
cb
.
before_train
(
self
.
trainer
)
cb
.
before_train
()
def
_after_train
(
self
):
def
_after_train
(
self
):
for
cb
in
self
.
cbs
:
for
cb
in
self
.
cbs
:
...
@@ -141,7 +149,7 @@ class Callbacks(Callback):
...
@@ -141,7 +149,7 @@ class Callbacks(Callback):
with
tm
.
timed_callback
(
'restore checkpoint'
):
with
tm
.
timed_callback
(
'restore checkpoint'
):
self
.
test_callback_context
.
restore_checkpoint
()
self
.
test_callback_context
.
restore_checkpoint
()
test_sess_restored
=
True
test_sess_restored
=
True
with
self
.
test_callback_context
.
t
rigger_epoch
_context
(),
\
with
self
.
test_callback_context
.
t
est
_context
(),
\
tm
.
timed_callback
(
type
(
cb
)
.
__name__
):
tm
.
timed_callback
(
type
(
cb
)
.
__name__
):
cb
.
trigger_epoch
()
cb
.
trigger_epoch
()
tm
.
log
()
tm
.
log
()
tensorpack/callbacks/param.py
View file @
4aaf06ca
...
@@ -30,7 +30,7 @@ class HyperParamSetter(Callback):
...
@@ -30,7 +30,7 @@ class HyperParamSetter(Callback):
self
.
shape
=
shape
self
.
shape
=
shape
self
.
last_value
=
None
self
.
last_value
=
None
def
_
before_train
(
self
):
def
_
setup_graph
(
self
):
all_vars
=
tf
.
all_variables
()
all_vars
=
tf
.
all_variables
()
for
v
in
all_vars
:
for
v
in
all_vars
:
if
v
.
name
==
self
.
var_name
:
if
v
.
name
==
self
.
var_name
:
...
@@ -59,6 +59,12 @@ class HyperParamSetter(Callback):
...
@@ -59,6 +59,12 @@ class HyperParamSetter(Callback):
pass
pass
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
self
.
_set_param
()
def
_before_train
(
self
):
self
.
_set_param
()
def
_set_param
(
self
):
v
=
self
.
get_current_value
()
v
=
self
.
get_current_value
()
if
v
is
not
None
:
if
v
is
not
None
:
self
.
assign_op
.
eval
(
feed_dict
=
{
self
.
val_holder
:
v
})
self
.
assign_op
.
eval
(
feed_dict
=
{
self
.
val_holder
:
v
})
...
...
tensorpack/callbacks/summary.py
View file @
4aaf06ca
...
@@ -70,8 +70,6 @@ class StatHolder(object):
...
@@ -70,8 +70,6 @@ class StatHolder(object):
def
_write_stat
(
self
):
def
_write_stat
(
self
):
tmp_filename
=
self
.
filename
+
'.tmp'
tmp_filename
=
self
.
filename
+
'.tmp'
with
open
(
tmp_filename
,
'w'
)
as
f
:
with
open
(
tmp_filename
,
'w'
)
as
f
:
import
IPython
;
IPython
.
embed
(
config
=
IPython
.
terminal
.
ipapp
.
load_default_config
())
json
.
dump
(
self
.
stat_history
,
f
)
json
.
dump
(
self
.
stat_history
,
f
)
os
.
rename
(
tmp_filename
,
self
.
filename
)
os
.
rename
(
tmp_filename
,
self
.
filename
)
...
...
tensorpack/train/base.py
View file @
4aaf06ca
...
@@ -81,7 +81,7 @@ class Trainer(object):
...
@@ -81,7 +81,7 @@ class Trainer(object):
self
.
_init_summary
()
self
.
_init_summary
()
get_global_step_var
()
# ensure there is such var, before finalizing the graph
get_global_step_var
()
# ensure there is such var, before finalizing the graph
callbacks
=
self
.
config
.
callbacks
callbacks
=
self
.
config
.
callbacks
callbacks
.
before_train
(
self
)
callbacks
.
setup_graph
(
self
)
self
.
config
.
session_init
.
init
(
self
.
sess
)
self
.
config
.
session_init
.
init
(
self
.
sess
)
tf
.
get_default_graph
()
.
finalize
()
tf
.
get_default_graph
()
.
finalize
()
self
.
_start_all_threads
()
self
.
_start_all_threads
()
...
@@ -91,6 +91,7 @@ class Trainer(object):
...
@@ -91,6 +91,7 @@ class Trainer(object):
self
.
global_step
=
get_global_step
()
self
.
global_step
=
get_global_step
()
logger
.
info
(
"Start training with global_step={}"
.
format
(
self
.
global_step
))
logger
.
info
(
"Start training with global_step={}"
.
format
(
self
.
global_step
))
callbacks
.
before_train
()
for
epoch
in
range
(
self
.
config
.
starting_epoch
,
self
.
config
.
max_epoch
+
1
):
for
epoch
in
range
(
self
.
config
.
starting_epoch
,
self
.
config
.
max_epoch
+
1
):
with
timed_operation
(
with
timed_operation
(
'Epoch {}, global_step={}'
.
format
(
'Epoch {}, global_step={}'
.
format
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment