Shashank Suhas / seminar-breakout · Commits

Commit efe3dfb5, authored Jul 13, 2017 by Yuxin Wu
Parent: 079eb3a9

    remove the many levels of Trainer hierarchy

Showing 3 changed files, with 30 additions and 20 deletions (+30 / -20).
Files changed:
    tensorpack/train/distributed.py    +3 / -3
    tensorpack/train/feedfree.py       +9 / -8
    tensorpack/train/multigpu.py       +18 / -9
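Taken together, the three diffs flatten the trainer inheritance: MultiGPUTrainerBase now derives from FeedfreeTrainerBase rather than Trainer, and the concrete multi-GPU and distributed trainers drop the SingleCostFeedfreeTrainer mixin, which survives only as a deprecated shim. Below is a minimal, self-contained sketch of the resulting hierarchy; the class names and bases follow the diff, but the stub bodies are illustrative stand-ins, not tensorpack code.

# Stand-in stubs only; the real classes live in tensorpack.train.*.
class Trainer(object):
    """Root trainer class."""

class FeedfreeTrainerBase(Trainer):
    """Trainer that reads data without feed_dict (expects a FeedfreeInput)."""

class MultiGPUTrainerBase(FeedfreeTrainerBase):   # was: MultiGPUTrainerBase(Trainer)
    """Base class for multi-gpu training."""

# was: SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer)
class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
    pass

# was: DistributedReplicatedTrainer(SingleCostFeedfreeTrainer)
class DistributedReplicatedTrainer(MultiGPUTrainerBase):
    pass

# Each trainer now has one linear chain instead of a mixin diamond over Trainer:
print([c.__name__ for c in SyncMultiGPUTrainerParameterServer.__mro__])
# ['SyncMultiGPUTrainerParameterServer', 'MultiGPUTrainerBase',
#  'FeedfreeTrainerBase', 'Trainer', 'object']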
tensorpack/train/distributed.py  (+3 / -3)

@@ -8,7 +8,6 @@ import os
 from six.moves import range
 from ..utils import logger
-from .feedfree import SingleCostFeedfreeTrainer
 from .multigpu import MultiGPUTrainerBase
 from ..callbacks import RunOp
 from ..tfutils.sesscreate import NewSessionCreator
@@ -35,7 +34,7 @@ class OverrideToLocalVariable(object):
             return getter(name, *args, **kwargs)
 
 
-class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
+class DistributedReplicatedTrainer(MultiGPUTrainerBase):
     """
     Distributed replicated training.
     Each worker process builds the same model on one or more GPUs.
@@ -191,7 +190,8 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         # Ngpu * Nvar * 2
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower,
-            lambda: self._get_cost_and_grad()[1],
+            lambda: MultiGPUTrainerBase._build_graph_get_grads(
+                self.model, self._input_source),
             devices=self.raw_devices,
             var_strategy='replicated',
             vs_names=None)  # use the default vs names
tensorpack/train/feedfree.py  (+9 / -8)

@@ -20,10 +20,8 @@ class FeedfreeTrainerBase(Trainer):
     Expect ``self.data`` to be a :class:`FeedfreeInput`.
     """
-    # TODO deprecated
+    @deprecated("Please build the graph yourself, e.g. by self.model.build_graph(self._input_source)")
     def build_train_tower(self):
-        logger.warn("build_train_tower() was deprecated! Please build the graph "
-                    "yourself, e.g. by self.model.build_graph(self._input_source)")
         with TowerContext('', is_training=True):
             self.model.build_graph(self._input_source)
@@ -36,16 +34,20 @@ class FeedfreeTrainerBase(Trainer):
         self.hooked_sess.run(self.train_op)
 
 
-# TODO Kept for now for back-compat
+# deprecated
 class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
     """ A feedfree Trainer which assumes a single cost. """
+    def __init__(self, *args, **kwargs):
+        super(SingleCostFeedfreeTrainer, self).__init__(*args, **kwargs)
+        logger.warn("SingleCostFeedfreeTrainer was deprecated!")
+
     def _get_cost_and_grad(self):
         """ get the cost and gradient"""
         self.model.build_graph(self._input_source)
         return self.model.get_cost_and_grad()
 
 
-@deprecated("Use SimpleTrainer with config.data instead!")
+@deprecated("Use SimpleTrainer with config.data is the same!", "2017-09-13")
 def SimpleFeedfreeTrainer(config):
     assert isinstance(config.data, FeedfreeInput), config.data
     return SimpleTrainer(config)
@@ -53,9 +55,8 @@ def SimpleFeedfreeTrainer(config):
 def QueueInputTrainer(config, input_queue=None):
     """
-    A wrapper trainer which automatically wraps ``config.dataflow`` by a
-    :class:`QueueInput`.
-    It is an equivalent of ``SimpleFeedfreeTrainer(config)`` with ``config.data = QueueInput(dataflow)``.
+    A wrapper trainer which automatically wraps ``config.dataflow`` by a :class:`QueueInput`.
+    It is an equivalent of ``SimpleTrainer(config)`` with ``config.data = QueueInput(dataflow)``.
 
     Args:
         config (TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
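The feedfree.py hunks also replace hand-written warnings (a TODO comment plus a logger.warn() inside the method body) with the @deprecated(...) decorator, called with a message alone for build_train_tower and with a message plus a date string ("2017-09-13") for SimpleFeedfreeTrainer. Below is a rough, self-contained imitation of that call shape; it is not tensorpack's deprecated implementation, only an illustration of why the decorator reads better than per-function warning boilerplate.

import functools
import warnings

def deprecated(message, eol_date=None):
    """Illustrative decorator factory: emit a warning, then delegate to the wrapped function."""
    def wrapper(func):
        @functools.wraps(func)
        def inner(*args, **kwargs):
            text = "{} is deprecated: {}".format(func.__name__, message)
            if eol_date is not None:
                text += " (to be removed after {})".format(eol_date)
            warnings.warn(text, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)
        return inner
    return wrapper

@deprecated("Use SimpleTrainer with config.data is the same!", "2017-09-13")
def SimpleFeedfreeTrainer(config):
    ...  # body unchanged in the diff; elided here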
tensorpack/train/multigpu.py  (+18 / -9)

@@ -15,9 +15,8 @@ from ..tfutils.collection import backup_collection, restore_collection
 from ..tfutils.gradproc import ScaleGradient
 from ..callbacks.graph import RunOp
-from .base import Trainer
-from .feedfree import SingleCostFeedfreeTrainer
 from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput
+from .feedfree import FeedfreeTrainerBase
 
 __all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
            'AsyncMultiGPUTrainer', 'LeastLoadedDeviceSetter',
@@ -44,7 +43,7 @@ def apply_prefetch_policy(config, gpu_prefetch=True):
         config.data = StagingInputWrapper(config.data, devices)
 
 
-class MultiGPUTrainerBase(Trainer):
+class MultiGPUTrainerBase(FeedfreeTrainerBase):
     """ Base class for multi-gpu training"""
     @staticmethod
     def build_on_multi_tower(
@@ -116,6 +115,11 @@ class MultiGPUTrainerBase(Trainer):
         nvars = [len(k) for k in grad_list]
         assert len(set(nvars)) == 1, "Number of gradients from each tower is different! " + str(nvars)
 
+    @staticmethod
+    def _build_graph_get_grads(model, input):
+        model.build_graph(input)
+        return model.get_cost_and_grad()[1]
+
 
 # Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
 class LeastLoadedDeviceSetter(object):
@@ -148,7 +152,7 @@ class LeastLoadedDeviceSetter(object):
         return sanitize_name(device_name)
 
 
-class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
+class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
     """
     A data-parallel Multi-GPU trainer which synchronoizes the gradients computed
     from each tower, averages them and update to variables stored across all
@@ -199,7 +203,9 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
                           worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
-            self.config.tower, lambda: self._get_cost_and_grad()[1], devices)
+            self.config.tower,
+            lambda: MultiGPUTrainerBase._build_graph_get_grads(
+                self.model, self._input_source), devices)
         MultiGPUTrainerBase._check_grad_list(grad_list)
 
         # debug tower performance (without update):
@@ -223,7 +229,7 @@ def SyncMultiGPUTrainer(config):
     return SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')
 
 
-class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
+class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
     """
     Data-parallel Multi-GPU trainer where each GPU contains a replicate of the
     whole model. Each gradient update is broadcast and synced.
@@ -266,7 +272,8 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower,
-            lambda: self._get_cost_and_grad()[1],
+            lambda: MultiGPUTrainerBase._build_graph_get_grads(
+                self.model, self._input_source),
             var_strategy='replicated',
             # use no variable scope for the first tower
             vs_names=[''] + [None] * (self.config.nr_tower - 1))
@@ -308,7 +315,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         return tf.group(*post_init_ops, name='sync_variables_from_tower0')
 
 
-class AsyncMultiGPUTrainer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
+class AsyncMultiGPUTrainer(MultiGPUTrainerBase):
     """
     A multi-tower multi-GPU trainer where each tower independently
     asynchronously updates the model without averaging the gradient.
@@ -330,7 +337,9 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
         devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
-            self.config.tower, lambda: self._get_cost_and_grad()[1], devices)
+            self.config.tower,
+            lambda: MultiGPUTrainerBase._build_graph_get_grads(
+                self.model, self._input_source), devices)
         MultiGPUTrainerBase._check_grad_list(grad_list)
         if self._scale_gradient and self.config.nr_tower > 1:
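The key addition in multigpu.py is the static helper MultiGPUTrainerBase._build_graph_get_grads(model, input): every trainer now hands build_on_multi_tower a plain callable closed over (self.model, self._input_source) instead of a bound SingleCostFeedfreeTrainer method, which is what lets the extra base class go away. A minimal sketch of that pattern follows; only _build_graph_get_grads mirrors the diff, while the toy model and tower loop are illustrative stand-ins.

class FakeModel(object):
    """Toy model: build_graph() wires up the graph, get_cost_and_grad() returns (cost, grads)."""
    def build_graph(self, input_source):
        self._input = input_source

    def get_cost_and_grad(self):
        return 0.0, ['grad_wrt_{}'.format(self._input)]

def _build_graph_get_grads(model, input):
    # Same shape as the new static helper: build the graph, return only the gradients.
    model.build_graph(input)
    return model.get_cost_and_grad()[1]

def build_on_multi_tower(towers, func):
    """Toy stand-in for the per-tower loop: call func once per tower."""
    return [func() for _ in towers]

model, input_source = FakeModel(), 'queue_input'
grad_list = build_on_multi_tower(
    towers=[0, 1],
    func=lambda: _build_graph_get_grads(model, input_source))
print(grad_list)  # one gradient list per tower

The same lambda shape appears in distributed.py and in all three multi-GPU trainers above, so none of them needs _get_cost_and_grad() from the old mixin any more.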