Commit fd9edc3b authored Aug 04, 2017 by Yuxin Wu
Remove unnecessary var_strategy checks in trainers.
Parent: 9b707d91
Showing 5 changed files with 10 additions and 15 deletions (+10, -15):
docs/tutorial/extend/trainer.md    +4 -3
tensorpack/models/batch_norm.py    +4 -2
tensorpack/train/distributed.py    +0 -1
tensorpack/train/multigpu.py       +1 -8
tensorpack/utils/naming.py         +1 -1
docs/tutorial/extend/trainer.md

@@ -8,7 +8,7 @@ or write an issue to see if there is a better solution than creating new trainer
 For certain tasks, you do need a new trainer.
 Trainers just run __some__ iterations, so there is no limit in where the data come from or what to do in an iteration.
-The existing common trainers do two things:
+The existing common trainers all do two things:
 1. Setup the graph and input pipeline, from `TrainConfig`.
 2. Minimize `model.cost` in each iteration.

@@ -16,11 +16,12 @@ But you can customize it by using the base `Trainer` class.
 * To customize the graph:
-  Create the graph, add any tensors and ops either before creating the trainer or inside `Trainer.__init__`.
+  Add any tensors and ops you like, either before creating the trainer or inside `Trainer.__init__`.
+  In this case you don't need to set model/data in `TrainConfig` any more.
 * Two ways to customize the iteration:
   1. Set `Trainer.train_op`. This op will be run by default.
-  2. Subclass `Trainer` and override the `run_step()` method. This way you can run more ops in one iteration.
+  2. Subclass `Trainer` and override the `run_step()` method. This way you can do something more than running an op.

 There are several different [GAN trainers](../../examples/GAN/GAN.py) for reference.
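(A minimal sketch of the second customization route described in this tutorial excerpt. The class name, the two ops, the `config` argument, and the `self.hooked_sess` handle are illustrative assumptions rather than part of this commit; see the GAN trainers linked above for real implementations. The first route would simply assign `Trainer.train_op` instead of overriding `run_step()`.)

from tensorpack import Trainer

class TwoOpTrainer(Trainer):
    """Hypothetical trainer that runs two ops per iteration instead of one train_op."""

    def __init__(self, config, op_a, op_b):
        self._op_a, self._op_b = op_a, op_b
        super(TwoOpTrainer, self).__init__(config)

    def run_step(self):
        # Override run_step() to do more than running a single op per iteration.
        self.hooked_sess.run(self._op_a)
        self.hooked_sess.run(self._op_b)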
tensorpack/models/batch_norm.py

@@ -213,11 +213,13 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
         xn = layer.apply(x, training=ctx.is_training, scope=tf.get_variable_scope())
         if ctx.has_own_variables:
-            # only apply update in this case
+            # Only apply update in this case.
+            # Add these EMA to model_variables so that they will be synced
+            # properly by replicated trainers.
             for v in layer.non_trainable_variables:
                 add_model_variable(v)
         else:
-            # don't need update if we are sharing variables from an old tower
+            # Don't need update if we are sharing variables from an existing tower
             restore_collection(coll_bk)

     if ndims == 2:
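(Aside on the new comment: in TF 1.x, `add_model_variable` registers a variable in the `MODEL_VARIABLES` collection, which, per the comment added above, is what replicated trainers rely on to sync non-trainable state such as these EMA statistics. A minimal, self-contained illustration; the variable name and shape are made up:)

import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable

# A non-trainable EMA-style variable, similar to what BatchRenorm's layer creates.
ema_mean = tf.get_variable('bn/mean/EMA', shape=[64], trainable=False,
                           initializer=tf.zeros_initializer())
add_model_variable(ema_mean)

# The variable now shows up in the MODEL_VARIABLES collection.
assert ema_mean in tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)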
tensorpack/train/distributed.py

@@ -199,7 +199,6 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
                 lambda: MultiGPUTrainerBase._build_graph_get_grads(
                     self.model, self._input_source),
                 devices=self.raw_devices,
-                var_strategy='replicated',
                 vs_names=[True] * self.config.nr_tower)  # open vs at each tower
             MultiGPUTrainerBase._check_grad_list(grad_list)
tensorpack/train/multigpu.py

@@ -50,14 +50,13 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase):
     @staticmethod
     def build_on_multi_tower(
             towers, func,
-            devices=None, var_strategy='shared',
+            devices=None,
             use_vs=None):
         """
         Args:
             towers: list of gpu relative ids
             func: a lambda to be called inside each tower
             devices: a list of devices to be used. By default will use GPUs in ``towers``.
-            var_strategy (str): 'shared' or 'replicated'
             use_vs (list[bool]): list of use_vs to passed to TowerContext

         Returns:

@@ -73,11 +72,6 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase):
         tower_names = ['tower{}'.format(idx) for idx in range(len(towers))]
         keys_to_freeze = TOWER_FREEZE_KEYS[:]
-        if var_strategy == 'replicated':  # TODO ugly
-            logger.info("In replicated mode, UPDATE_OPS from all GPUs will be run.")
-            keys_to_freeze.remove(tf.GraphKeys.UPDATE_OPS)
-        else:
-            assert use_vs is None
         if use_vs is None:
             use_vs = [False] * len(towers)
         assert len(use_vs) == len(towers)

@@ -308,7 +302,6 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             tower,
             lambda: MultiGPUTrainerBase._build_graph_get_grads(model, input),
-            var_strategy='replicated',
             # use no variable scope for the first tower
             use_vs=[False] + [True] * (len(tower) - 1))
         grads = SyncMultiGPUTrainerReplicated._allreduce_grads(grad_list)
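(A hedged sketch of the calling convention before and after this change; the tower count is a placeholder and `build_fn` stands in for the per-tower lambda shown in the hunks above. The intent previously carried by `var_strategy='replicated'` is now expressed entirely through the per-tower `use_vs` flags:)

nr_tower = 4  # placeholder value for illustration only

# Before this commit, the replicated strategy was selected by a string:
#   MultiGPUTrainerBase.build_on_multi_tower(
#       towers, build_fn,
#       var_strategy='replicated',
#       use_vs=[False] + [True] * (nr_tower - 1))

# After this commit only the per-tower flags remain: the first tower keeps
# the root variable scope, every other tower opens its own.
use_vs = [False] + [True] * (nr_tower - 1)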
tensorpack/utils/naming.py

@@ -15,7 +15,7 @@ MOVING_SUMMARY_OPS_KEY = 'MOVING_SUMMARY_OPS'
 SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]

-TOWER_FREEZE_KEYS = SUMMARY_BACKUP_KEYS + [tf.GraphKeys.UPDATE_OPS]
+TOWER_FREEZE_KEYS = SUMMARY_BACKUP_KEYS

 # export all upper case variables
 all_local_names = locals().keys()
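(Taken together with the multigpu.py hunk above, the net effect of this commit can be summarized with a short sketch; the `_OLD`/`_NEW` names are mine, the values come from the diff. Since `UPDATE_OPS` is no longer in the default freeze list for extra towers, the replicated-only branch that used to remove it becomes unnecessary, and update ops from all towers are kept in every mode:)

import tensorflow as tf

MOVING_SUMMARY_OPS_KEY = 'MOVING_SUMMARY_OPS'
SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]

# Before: UPDATE_OPS was frozen for non-first towers unless
# var_strategy == 'replicated' explicitly removed it from the list.
TOWER_FREEZE_KEYS_OLD = SUMMARY_BACKUP_KEYS + [tf.GraphKeys.UPDATE_OPS]

# After: UPDATE_OPS is never frozen, so no special case is needed.
TOWER_FREEZE_KEYS_NEW = SUMMARY_BACKUP_KEYS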