Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
b6370d50
Commit
b6370d50
authored
Apr 02, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix multigpu naming problem
parent
c08297ff
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
2 deletions
+22
-2
tensorpack/callbacks/common.py
tensorpack/callbacks/common.py
+17
-0
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+2
-2
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+3
-0
No files found.
tensorpack/callbacks/common.py
View file @
b6370d50
...
@@ -20,9 +20,26 @@ class PeriodicSaver(PeriodicCallback):
...
@@ -20,9 +20,26 @@ class PeriodicSaver(PeriodicCallback):
def
_before_train
(
self
):
def
_before_train
(
self
):
self
.
path
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
'model'
)
self
.
path
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
'model'
)
self
.
saver
=
tf
.
train
.
Saver
(
self
.
saver
=
tf
.
train
.
Saver
(
var_list
=
self
.
_get_vars
(),
max_to_keep
=
self
.
keep_recent
,
max_to_keep
=
self
.
keep_recent
,
keep_checkpoint_every_n_hours
=
self
.
keep_freq
)
keep_checkpoint_every_n_hours
=
self
.
keep_freq
)
def
_get_vars
(
self
):
vars
=
tf
.
all_variables
()
var_dict
=
{}
for
v
in
vars
:
name
=
v
.
op
.
name
if
re
.
match
(
'tower[1-9]'
,
name
):
logger
.
info
(
"Skip {} when saving model."
.
format
(
name
))
continue
if
'tower0/'
in
name
:
new_name
=
name
.
replace
(
'tower0/'
,
''
)
logger
.
info
(
"{} renamed to {} when saving model."
.
format
(
name
,
new_name
))
name
=
new_name
var_dict
[
name
]
=
v
return
var_dict
def
_trigger_periodic
(
self
):
def
_trigger_periodic
(
self
):
self
.
saver
.
save
(
self
.
saver
.
save
(
tf
.
get_default_session
(),
tf
.
get_default_session
(),
...
...
tensorpack/models/batch_norm.py
View file @
b6370d50
...
@@ -40,9 +40,9 @@ def BatchNorm(x, use_local_stat=True, decay=0.999, epsilon=1e-5):
...
@@ -40,9 +40,9 @@ def BatchNorm(x, use_local_stat=True, decay=0.999, epsilon=1e-5):
initializer
=
tf
.
constant_initializer
(
1.0
))
initializer
=
tf
.
constant_initializer
(
1.0
))
if
len
(
shape
)
==
2
:
if
len
(
shape
)
==
2
:
batch_mean
,
batch_var
=
tf
.
nn
.
moments
(
x
,
[
0
],
name
=
'moments'
,
keep_dims
=
False
)
batch_mean
,
batch_var
=
tf
.
nn
.
moments
(
x
,
[
0
],
keep_dims
=
False
)
else
:
else
:
batch_mean
,
batch_var
=
tf
.
nn
.
moments
(
x
,
[
0
,
1
,
2
],
name
=
'moments'
,
keep_dims
=
False
)
batch_mean
,
batch_var
=
tf
.
nn
.
moments
(
x
,
[
0
,
1
,
2
],
keep_dims
=
False
)
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
)
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
)
ema_apply_op
=
ema
.
apply
([
batch_mean
,
batch_var
])
ema_apply_op
=
ema
.
apply
([
batch_mean
,
batch_var
])
...
...
tensorpack/train/trainer.py
View file @
b6370d50
...
@@ -118,6 +118,8 @@ class QueueInputTrainer(Trainer):
...
@@ -118,6 +118,8 @@ class QueueInputTrainer(Trainer):
tf
.
name_scope
(
'tower{}'
.
format
(
i
))
as
scope
:
tf
.
name_scope
(
'tower{}'
.
format
(
i
))
as
scope
:
model_inputs
=
get_model_inputs
()
model_inputs
=
get_model_inputs
()
cost_var
=
model
.
get_cost
(
model_inputs
,
is_training
=
True
)
cost_var
=
model
.
get_cost
(
model_inputs
,
is_training
=
True
)
if
i
==
0
:
cost_var_t0
=
cost_var
grad_list
.
append
(
grad_list
.
append
(
self
.
config
.
optimizer
.
compute_gradients
(
cost_var
))
self
.
config
.
optimizer
.
compute_gradients
(
cost_var
))
...
@@ -129,6 +131,7 @@ class QueueInputTrainer(Trainer):
...
@@ -129,6 +131,7 @@ class QueueInputTrainer(Trainer):
del
tf
.
get_collection
(
k
)[:]
del
tf
.
get_collection
(
k
)[:]
tf
.
get_collection
(
k
)
.
extend
(
kept_summaries
[
k
])
tf
.
get_collection
(
k
)
.
extend
(
kept_summaries
[
k
])
grads
=
QueueInputTrainer
.
_average_grads
(
grad_list
)
grads
=
QueueInputTrainer
.
_average_grads
(
grad_list
)
cost_var
=
cost_var_t0
else
:
else
:
model_inputs
=
get_model_inputs
()
model_inputs
=
get_model_inputs
()
cost_var
=
model
.
get_cost
(
model_inputs
,
is_training
=
True
)
cost_var
=
model
.
get_cost
(
model_inputs
,
is_training
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment