Commit 6bee3c24, authored Dec 17, 2017 by Yuxin Wu
Add overridecachingdevice & colocate option
Parent: 6efe0deb
Showing 2 changed files with 51 additions and 12 deletions.
tensorpack/graph_builder/training.py: +3, -6
tensorpack/graph_builder/utils.py: +48, -6
tensorpack/graph_builder/training.py

@@ -14,7 +14,7 @@ from ..tfutils.gradproc import ScaleGradient
 from .utils import (
     LeastLoadedDeviceSetter, override_to_local_variable,
-    allreduce_grads, average_grads_with_colocation)
+    allreduce_grads, average_grads)

 __all__ = ['GraphBuilder',
@@ -109,16 +109,13 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
     It is an equivalent of ``--variable_update=parameter_server`` in
     `tensorflow/benchmarks <https://github.com/tensorflow/benchmarks>`_.
     """
-    def __init__(self, towers, ps_device=None):
+    def __init__(self, towers, ps_device):
         """
         Args:
             towers(list[int]): list of GPU id
             ps_device (str): either 'gpu' or 'cpu', where variables are stored.
-                Setting to 'cpu' might help when #gpu>=4
         """
         super(SyncMultiGPUParameterServerBuilder, self).__init__(towers)
-        if ps_device is None:
-            ps_device = 'cpu' if len(towers) >= 4 else 'gpu'
         assert ps_device in ['cpu', 'gpu']
         self.ps_device = ps_device
@@ -146,7 +143,7 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
         # self.train_op = tf.group(*ops)
         # return
-        grads = average_grads_with_colocation(grad_list)
+        grads = average_grads(grad_list, colocation=True)
         # grads = grad_list[0]
         opt = get_opt_fn()
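After this change, `SyncMultiGPUParameterServerBuilder` no longer infers a parameter-server device; callers must pass `ps_device` explicitly. A minimal sketch of a call site under the new signature (the tower list is illustrative, and the reproduced heuristic is simply the default this commit removed):

    from tensorpack.graph_builder.training import SyncMultiGPUParameterServerBuilder

    towers = [0, 1, 2, 3]  # GPU ids, illustrative
    # Callers who relied on the removed default can reproduce its heuristic:
    ps_device = 'cpu' if len(towers) >= 4 else 'gpu'
    builder = SyncMultiGPUParameterServerBuilder(towers, ps_device)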
tensorpack/graph_builder/utils.py

@@ -8,9 +8,11 @@ import operator
 import tensorflow as tf

-__all__ = ['LeastLoadedDeviceSetter', 'OverrideToLocalVariable',
-           'override_to_local_variable', 'allreduce_grads',
-           'average_grads_with_colocation']
+__all__ = ['LeastLoadedDeviceSetter',
+           'OverrideCachingDevice',
+           'OverrideToLocalVariable', 'override_to_local_variable',
+           'allreduce_grads', 'average_grads']

 """
@@ -115,13 +117,14 @@ def allreduce_grads(all_grads):
     return ret


-def average_grads_with_colocation(all_grads):
+def average_grads(all_grads, colocation=True):
     """
     Average the gradients, on the device of each variable.

     Args:
         all_grads (K x N x 2): A list of K lists. Each of the list is a list of N (grad, var) tuples.
             The variables have to be the same across the K lists.
+        colocation (bool): colocate gradient averaging with the variable

     Returns:
         (N x 2): A list of N (grad, var) tuples, where grad is averaged over K.
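To make the K x N x 2 contract concrete: given K towers that each hold gradients for the same N variables, the function returns N (grad, var) pairs whose gradients are per-variable means over the towers. A minimal TF1-style sketch (variable names and gradient values are illustrative; the import path follows this commit's module layout):

    import tensorflow as tf
    from tensorpack.graph_builder.utils import average_grads

    w = tf.get_variable('w', shape=[10])
    b = tf.get_variable('b', shape=[10])

    # K=2 towers, each a list of N=2 (grad, var) tuples over the same variables.
    tower0 = [(tf.ones([10]) * 1.0, w), (tf.ones([10]) * 1.0, b)]
    tower1 = [(tf.ones([10]) * 3.0, w), (tf.ones([10]) * 3.0, b)]

    # Returns N=2 (grad, var) tuples; each grad is add_n(grads) / K (2.0 here).
    # With colocation=True the averaging ops are placed on each variable's device.
    averaged = average_grads([tower0, tower1], colocation=True)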
@@ -137,8 +140,47 @@ def average_grads_with_colocation(all_grads):
         v = grad_and_vars[0][1]
         grads = [g for (g, _) in grad_and_vars]
-        with tf.device(v.device):   # colocate summed grad with var
-            grad = tf.multiply(tf.add_n(grads), 1.0 / nr_tower)
+        if colocation:
+            with tf.device(v.device):   # colocate summed grad with var
+                grad = tf.multiply(tf.add_n(grads), 1.0 / nr_tower)
+        else:
+            grad = tf.multiply(tf.add_n(grads), 1.0 / nr_tower)
         ret.append((grad, v))
     return ret
+
+
+# https://github.com/tensorflow/benchmarks/blob/48cbef14a592e02a14beee8e9aef3ad22cadaed1/scripts/tf_cnn_benchmarks/variable_mgr_util.py#L140-L166
+class OverrideCachingDevice(object):
+    """Variable getter which caches variables on the least loaded device.
+
+    Variables smaller than a certain threshold are cached on a single specific
+    device, as specified in the constructor. All other variables are load balanced
+    across a pool of devices, by caching each variable on the least loaded device.
+    """
+    def __init__(self, devices, device_for_small_variables,
+                 small_variable_size_threshold):
+        self.devices = devices
+        self.sizes = [0] * len(self.devices)
+        self.device_for_small_variables = device_for_small_variables
+        self.small_variable_size_threshold = small_variable_size_threshold
+
+    def __call__(self, getter, *args, **kwargs):
+        size = tf.TensorShape(kwargs['shape']).num_elements()
+        if size is None:
+            # print(args, kwargs)
+            return getter(*args, **kwargs)
+        if kwargs.get('trainable', True) == False:
+            return getter(*args, **kwargs)
+
+        if size < self.small_variable_size_threshold:
+            device_name = self.device_for_small_variables
+        else:
+            device_index, _ = min(enumerate(self.sizes), key=operator.itemgetter(1))
+            device_name = self.devices[device_index]
+            self.sizes[device_index] += size
+
+        kwargs['caching_device'] = device_name
+        var = getter(*args, **kwargs)
+        return var
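Like its source in tensorflow/benchmarks, `OverrideCachingDevice` is meant to be plugged in as a `custom_getter`, so every variable created under a scope gets a `caching_device` assigned by the policy above. A hedged usage sketch (device strings, scope name, and threshold are illustrative, not from the commit):

    import tensorflow as tf
    from tensorpack.graph_builder.utils import OverrideCachingDevice

    # Cache variables smaller than 1024 elements on the CPU; load-balance the
    # rest across GPUs by tracking the cumulative size cached on each device.
    getter = OverrideCachingDevice(
        devices=['/gpu:0', '/gpu:1'],
        device_for_small_variables='/cpu:0',
        small_variable_size_threshold=1024)

    with tf.variable_scope('model', custom_getter=getter):
        small = tf.get_variable('small', shape=[100])   # cached on /cpu:0
        big = tf.get_variable('big', shape=[64, 64])    # 4096 elements -> least-loaded GPU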