Shashank Suhas / seminar-breakout

Commit 76fe1b6b
authored Apr 18, 2016 by Yuxin Wu

    update cifar number & fix multigpu restore bug

parent da3da39d

Showing 7 changed files with 66 additions and 16 deletions:
    examples/ResNet/cifar10_resnet.py   +1  -1
    examples/ResNet/svhn_resnet.py      +6  -6
    examples/cifar10_convnet.py         +2  -2
    tensorpack/callbacks/common.py      +3  -2
    tensorpack/callbacks/group.py       +1  -0
    tensorpack/tfutils/sessinit.py      +52 -4
    tensorpack/train/base.py            +1  -1
examples/ResNet/cifar10_resnet.py

@@ -24,7 +24,7 @@ This implementation uses the variants proposed in:
 Identity Mappings in Deep Residual Networks, arxiv:1603.05027
 I can reproduce the results for
-n=5, about 7.6% val error
+n=5, about 7.2% val error after 93k step with 2 TitanX (6.8it/s)
 n=18, about 6.05% val error after 62k step with 2 TitanX (about 10hr)
 n=30: a 182-layer network, about 5.5% val error after 51k step with 2 GPUs
 This model uses the whole training set instead of a 95:5 train-val split.
examples/ResNet/svhn_resnet.py

@@ -20,8 +20,9 @@ from tensorpack.dataflow import imgaug
 """
-Reach 1.9% validation error after 90 epochs, with 2 GPUs.
-You might need to adjust learning rate schedule when running with 1 GPU.
+ResNet-110 for SVHN Digit Classification.
+Reach 1.9% validation error after 90 epochs, with 2 TitanX xxhr, 2it/s.
+You might need to adjust the learning rate schedule when running with 1 GPU.
 """

 BATCH_SIZE = 128
@@ -98,8 +99,7 @@ class Model(ModelDesc):
         logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity)
         prob = tf.nn.softmax(logits, name='output')

-        y = one_hot(label, 10)
-        cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
+        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
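Note on the hunk above: the sparse variant of the loss consumes integer class labels directly, so the manual one_hot step can be dropped. A minimal numpy sketch (not the TF API; names here are illustrative) of why the two formulations agree:

    import numpy as np

    def log_softmax(logits):
        # subtract the row max for numerical stability
        z = logits - logits.max(axis=1, keepdims=True)
        return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

    logits = np.array([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
    labels = np.array([0, 2])                          # integer class ids

    logp = log_softmax(logits)
    sparse_ce = -logp[np.arange(len(labels)), labels]   # sparse form: index the true class
    dense_ce = -(np.eye(3)[labels] * logp).sum(axis=1)  # dense form: dot with one-hot target
    assert np.allclose(sparse_ce, dense_ce)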
@@ -167,8 +167,8 @@ def get_config():
         optimizer=tf.train.MomentumOptimizer(lr, 0.9),
         callbacks=Callbacks([
             StatPrinter(),
-            PeriodicSaver(),
+            ModelSaver(),
-            ValidationError(dataset_test, prefix='test'),
+            ClassificationError(dataset_test, prefix='validation'),
             ScheduledHyperParamSetter('learning_rate',
                                       [(1, 0.1), (20, 0.01), (33, 0.001), (60, 0.0001)])
         ]),
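The schedule passed to ScheduledHyperParamSetter is a list of (epoch, value) pairs. As a sketch only (a hypothetical stand-alone helper, not tensorpack's implementation), such a piecewise-constant schedule can be read like this:

    # Hypothetical helper: return the value of an [(epoch, value), ...]
    # schedule at a given epoch, holding each value until the next breakpoint.
    def schedule_value(schedule, epoch):
        current = None
        for e, v in sorted(schedule):
            if epoch >= e:
                current = v
        return current

    schedule = [(1, 0.1), (20, 0.01), (33, 0.001), (60, 0.0001)]
    assert schedule_value(schedule, 5) == 0.1
    assert schedule_value(schedule, 25) == 0.01
    assert schedule_value(schedule, 100) == 0.0001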
examples/cifar10_convnet.py

@@ -114,7 +114,7 @@ def get_config():
     lr = tf.train.exponential_decay(
         learning_rate=1e-2,
         global_step=get_global_step_var(),
-        decay_steps=step_per_epoch * 30 if nr_gpu == 1 else 20,
+        decay_steps=step_per_epoch * (30 if nr_gpu == 1 else 20),
         decay_rate=0.5, staircase=True, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
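The decay_steps change is an operator-precedence fix: Python's conditional expression binds more loosely than *, so the old line parsed as (step_per_epoch * 30) if nr_gpu == 1 else 20, decaying every 20 raw steps rather than every 20 epochs on multi-GPU runs. A quick demonstration (the step_per_epoch value is chosen for illustration):

    step_per_epoch, nr_gpu = 390, 2

    buggy = step_per_epoch * 30 if nr_gpu == 1 else 20    # == 20 (raw steps!)
    fixed = step_per_epoch * (30 if nr_gpu == 1 else 20)  # == 7800 (20 epochs)
    assert (buggy, fixed) == (20, 7800)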
@@ -129,7 +129,7 @@ def get_config():
         session_config=sess_config,
         model=Model(),
         step_per_epoch=step_per_epoch,
-        max_epoch=3,
+        max_epoch=20,
     )

 if __name__ == '__main__':
tensorpack/callbacks/common.py

@@ -26,11 +26,12 @@ class ModelSaver(Callback):
     def _before_train(self):
         self.path = os.path.join(logger.LOG_DIR, 'model')
         self.saver = tf.train.Saver(
-            var_list=self._get_vars(),
+            var_list=ModelSaver._get_vars(),
             max_to_keep=self.keep_recent,
             keep_checkpoint_every_n_hours=self.keep_freq)

-    def _get_vars(self):
+    @staticmethod
+    def _get_vars():
         vars = tf.all_variables()
         var_dict = {}
         for v in vars:
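_get_vars never touches self, so the hunk above promotes it to a @staticmethod and calls it through the class. A minimal sketch of the semantics (illustrative class, not the real callback):

    class Saverish:
        @staticmethod
        def _get_vars():
            return ['w', 'b']   # stand-in for the tf.all_variables() bookkeeping

    assert Saverish._get_vars() == ['w', 'b']    # callable without an instance
    assert Saverish()._get_vars() == ['w', 'b']  # and still via instances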
tensorpack/callbacks/group.py

@@ -70,6 +70,7 @@ class TestCallbackContext(object):
         with create_test_session(trainer) as sess:
             self.sess = sess
             self.graph = sess.graph
+            # no tower in test graph. just keep it as what it is
             self.saver = tf.train.Saver()
         with self.graph.as_default(), self.sess.as_default():
             yield
tensorpack/tfutils/sessinit.py

@@ -5,6 +5,8 @@
 import os
 from abc import abstractmethod, ABCMeta
 import numpy as np
+from collections import defaultdict
+import re
 import tensorflow as tf
 import six

@@ -38,7 +40,7 @@ class NewSession(SessionInit):
 class SaverRestore(SessionInit):
     """
-    Restore an old model saved by `tf.Saver`.
+    Restore an old model saved by `ModelSaver`.
     """
     def __init__(self, model_path):
         """

@@ -52,14 +54,60 @@ class SaverRestore(SessionInit):
         self.set_path(model_path)

     def _init(self, sess):
-        saver = tf.train.Saver()
-        saver.restore(sess, self.path)
         logger.info(
-            "Restore checkpoint from {}".format(self.path))
+            "Restoring checkpoint from {}.".format(self.path))
+        sess.run(tf.initialize_all_variables())
+        chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
+        vars_map = SaverRestore._get_vars_to_restore_multimap(chkpt_vars)
+        for dic in SaverRestore._produce_restore_dict(vars_map):
+            saver = tf.train.Saver(var_list=dic)
+            saver.restore(sess, self.path)

     def set_path(self, model_path):
         self.path = model_path

+    @staticmethod
+    def _produce_restore_dict(vars_multimap):
+        """
+        Produce {var_name: var} dict that can be used by `tf.train.Saver`, from a {var_name: [vars]} dict.
+        """
+        while len(vars_multimap):
+            ret = {}
+            for k in vars_multimap.keys():
+                v = vars_multimap[k]
+                ret[k] = v[-1]
+                del v[-1]
+                if not len(v):
+                    del vars_multimap[k]
+            yield ret
+
+    @staticmethod
+    def _read_checkpoint_vars(model_path):
+        reader = tf.train.NewCheckpointReader(model_path)
+        return set(reader.GetVariableToShapeMap().keys())
+
+    @staticmethod
+    def _get_vars_to_restore_multimap(vars_available):
+        """
+        Get a dict of {var_name: [var, var]} to restore
+        :param vars_available: variables available in the checkpoint, for existence checking
+        """
+        # TODO warn if some variable in checkpoint is not used
+        vars_to_restore = tf.all_variables()
+        var_dict = defaultdict(list)
+        for v in vars_to_restore:
+            name = v.op.name
+            if 'tower' in name:
+                new_name = re.sub('tower[0-9]+/', '', name)
+                name = new_name
+            if name in vars_available:
+                var_dict[name].append(v)
+            else:
+                logger.warn("Param {} not found in checkpoint! Will not restore.".format(v.op.name))
+        return var_dict

 class ParamRestore(SessionInit):
     """
     Restore trainable variables from a dictionary.
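This is the "multigpu restore bug" fix from the commit message. With data-parallel training, each GPU tower holds a replica such as tower0/conv1/W and tower1/conv1/W, while the checkpoint stores a single conv1/W; one tf.train.Saver cannot map one checkpoint name onto several variables, so the restore is split into multiple passes, each restoring at most one variable per name. A self-contained sketch of the multimap logic, using plain strings as stand-ins for variables:

    import re
    from collections import defaultdict

    graph_vars = ['tower0/conv1/W', 'tower1/conv1/W', 'fc/b', 'unsaved/x']
    checkpoint_vars = {'conv1/W', 'fc/b'}      # names found in the checkpoint

    var_dict = defaultdict(list)
    for v in graph_vars:
        name = re.sub('tower[0-9]+/', '', v)   # strip the tower prefix, if any
        if name in checkpoint_vars:
            var_dict[name].append(v)           # tower copies share one name

    def produce_restore_dict(vars_multimap):
        # yield {name: var} dicts until every copy has been handed out
        while vars_multimap:
            ret = {}
            for k in list(vars_multimap):
                ret[k] = vars_multimap[k].pop()
                if not vars_multimap[k]:
                    del vars_multimap[k]
            yield ret

    passes = list(produce_restore_dict(var_dict))
    assert passes == [{'conv1/W': 'tower1/conv1/W', 'fc/b': 'fc/b'},
                      {'conv1/W': 'tower0/conv1/W'}]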
tensorpack/train/base.py

@@ -83,7 +83,7 @@ class Trainer(object):
         self.global_step = get_global_step()
         logger.info("Start training with global_step={}".format(self.global_step))
-        for epoch in range(self.config.starting_epoch, self.config.max_epoch):
+        for epoch in range(self.config.starting_epoch, self.config.max_epoch + 1):
             with timed_operation(
                 'Epoch {}, global_step={}'.format(
                     epoch, self.global_step + self.config.step_per_epoch)):
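The base.py hunk is an off-by-one fix: range excludes its end point, so the old loop stopped one epoch short of config.max_epoch. Concretely:

    starting_epoch, max_epoch = 1, 20

    old_epochs = list(range(starting_epoch, max_epoch))      # last epoch run: 19
    new_epochs = list(range(starting_epoch, max_epoch + 1))  # last epoch run: 20
    assert (old_epochs[-1], new_epochs[-1]) == (19, 20)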