Shashank Suhas / seminar-breakout / Commits

Commit 6d7276b8
authored Aug 23, 2017 by Yuxin Wu
use LOCAL_VARIABLES in replicated trainer, so duplicated vars won't get saved
parent 4fa66545

Showing 5 changed files with 86 additions and 61 deletions
tensorpack/tfutils/sesscreate.py   +1  -0
tensorpack/tfutils/tower.py        +1  -1
tensorpack/train/distributed.py    +4  -21
tensorpack/train/multigpu.py       +13 -39
tensorpack/train/utility.py        +67 -0
tensorpack/tfutils/sesscreate.py

```diff
@@ -31,6 +31,7 @@ class NewSessionCreator(tf.train.SessionCreator):
     def create_session(self):
         sess = tf.Session(target=self.target, graph=self.graph, config=self.config)
         sess.run(tf.global_variables_initializer())
+        sess.run(tf.local_variables_initializer())
         logger.info("Global variables initialized.")
         return sess
```
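For context, a minimal sketch (not part of the diff) of why the added call is needed: `tf.global_variables_initializer()` only covers the GLOBAL_VARIABLES collection, so the variables this commit reroutes into LOCAL_VARIABLES would otherwise start uninitialized. The variable name below is hypothetical.

```python
import tensorflow as tf

# Hypothetical variable living only in LOCAL_VARIABLES, like the
# duplicated tower copies the replicated trainer now creates.
v = tf.get_variable('local_only', shape=[2],
                    collections=[tf.GraphKeys.LOCAL_VARIABLES])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # sess.run(v) here would raise FailedPreconditionError: 'local_only'
    # is not in GLOBAL_VARIABLES, so the line above did not touch it.
    sess.run(tf.local_variables_initializer())  # the call this commit adds
    print(sess.run(v))
```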
tensorpack/tfutils/tower.py

```diff
@@ -20,7 +20,7 @@ class TowerContext(object):
             tower_name (str): The name scope of the tower.
             is_training (bool): if None, automatically determine from tower_name.
             index (int): index of this tower, only used in training.
-            use_vs (bool): Open a variable scope with this name.
+            use_vs (bool): Open a new variable scope with this name.
         """
         self._name = tower_name
         self._is_training = bool(is_training)
```
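A hedged usage sketch of the constructor this docstring documents; the call mirrors the one made in multigpu.py at this commit, with the body left empty:

```python
from tensorpack.tfutils.tower import TowerContext

# With use_vs=True, variables created inside live under a new 'tower1'
# variable scope, per the clarified docstring above.
with TowerContext('tower1', is_training=True, index=1, use_vs=True):
    pass  # graph-building for this tower goes here
```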
tensorpack/train/distributed.py

```diff
@@ -8,30 +8,15 @@ import os
 from six.moves import range
 
 from ..utils import logger
-from .multigpu import MultiGPUTrainerBase
 from ..callbacks import RunOp
 from ..tfutils.sesscreate import NewSessionCreator
 from ..tfutils.common import get_global_step_var, get_op_tensor_name
 
-__all__ = ['DistributedReplicatedTrainer', 'DistributedTrainerReplicated']
-
-
-class OverrideToLocalVariable(object):
-    """
-    Ensures the created variable
-    is in LOCAL_VARIABLES and not GLOBAL_VARIBLES collection.
-    """
-    def __call__(self, getter, name, *args, **kwargs):
-        if 'collections' in kwargs:
-            collections = kwargs['collections']
-        if not collections:
-            collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
-        else:
-            collections = set(collections.copy())
-        collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
-        collections.add(tf.GraphKeys.LOCAL_VARIABLES)
-        kwargs['collections'] = list(collections)
-        return getter(name, *args, **kwargs)
+from .multigpu import MultiGPUTrainerBase
+from .utility import override_to_local_variable
+
+__all__ = ['DistributedReplicatedTrainer', 'DistributedTrainerReplicated']
 
 
 class DistributedTrainerReplicated(MultiGPUTrainerBase):
@@ -220,9 +205,7 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
         cbs = self._input_source.setup(self.model.get_inputs_desc())
         self.config.callbacks.extend(cbs)
 
-        with tf.variable_scope(
-                tf.get_variable_scope(),
-                custom_getter=OverrideToLocalVariable()):
+        with override_to_local_variable():
             # Ngpu * Nvar * 2
             grad_list = MultiGPUTrainerBase.build_on_multi_tower(
                 self.config.tower,
```
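The refactor replaces the inline variable_scope/custom_getter idiom with the context manager now defined in tensorpack/train/utility.py. A small sketch of the resulting behavior, assuming the TF 1.x behavior where `get_variable` forwards its default `collections=None` through the custom getter (variable names illustrative):

```python
import tensorflow as tf
from tensorpack.train.utility import override_to_local_variable

with override_to_local_variable():
    v = tf.get_variable('v', shape=[4])  # rerouted to LOCAL_VARIABLES

assert v in tf.local_variables()
assert v not in tf.global_variables()

with override_to_local_variable(enable=False):
    w = tf.get_variable('w', shape=[4])  # ordinary GLOBAL_VARIABLES variable

assert w in tf.global_variables()
```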
tensorpack/train/multigpu.py

```diff
@@ -4,7 +4,6 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 
 import tensorflow as tf
-import operator
 from six.moves import zip, range
 
 from ..utils import logger
@@ -17,6 +16,7 @@ from ..callbacks.graph import RunOp
 from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput
 from .base import Trainer
+from .utility import LeastLoadedDeviceSetter, override_to_local_variable
 
 __all__ = ['MultiGPUTrainerBase', 'LeastLoadedDeviceSetter',
            'SyncMultiGPUTrainerReplicated',
@@ -69,25 +69,28 @@ class MultiGPUTrainerBase(Trainer):
         ret = []
         if devices is not None:
             assert len(devices) == len(towers)
+        if use_vs is not None:
+            assert len(use_vs) == len(towers)
 
         tower_names = ['tower{}'.format(idx) for idx in range(len(towers))]
         keys_to_freeze = TOWER_FREEZE_KEYS[:]
-        if use_vs is None:
-            use_vs = [False] * len(towers)
-        assert len(use_vs) == len(towers)
 
         for idx, t in enumerate(towers):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
+            usevs = use_vs[idx] if use_vs is not None else False
             with tf.device(device), TowerContext(
                     tower_names[idx],
                     is_training=True,
                     index=idx,
-                    use_vs=use_vs[idx]):
+                    use_vs=usevs):
                 if idx == t:
                     logger.info("Building graph for training tower {}...".format(idx))
                 else:
                     logger.info("Building graph for training tower {} on device {}...".format(idx, device))
 
-                ret.append(func())
+                # When use_vs is True, use LOCAL_VARIABLES,
+                # so these duplicated variables won't be saved by default.
+                with override_to_local_variable(enable=usevs):
+                    ret.append(func())
 
                 if idx == 0:
@@ -111,37 +114,6 @@ class MultiGPUTrainerBase(Trainer):
         return model.get_cost_and_grad()[1]
 
 
-# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
-class LeastLoadedDeviceSetter(object):
-    """ Helper class to assign variables on the least loaded ps-device."""
-    def __init__(self, worker_device, ps_devices):
-        """
-        Args:
-            worker_device: the device to use for compute ops.
-            ps_devices: a list of device to use for Variable ops.
-        """
-        self.ps_devices = ps_devices
-        self.worker_device = worker_device
-        self.ps_sizes = [0] * len(self.ps_devices)
-
-    def __call__(self, op):
-        def sanitize_name(name):
-            # tensorflow/tensorflow#11484
-            return tf.DeviceSpec.from_string(name).to_string()
-
-        if op.device:
-            return op.device
-        if op.type not in ['Variable', 'VariableV2']:
-            return sanitize_name(self.worker_device)
-
-        device_index, _ = min(enumerate(self.ps_sizes), key=operator.itemgetter(1))
-        device_name = self.ps_devices[device_index]
-        var_size = op.outputs[0].get_shape().num_elements()
-        self.ps_sizes[device_index] += var_size
-
-        return sanitize_name(device_name)
-
-
 class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
     """
     A data-parallel multi-GPU trainer. It builds one tower on each GPU with
@@ -308,6 +280,8 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
         for idx in range(len(tower)):
             with tf.device(raw_devices[idx]):
                 grad_and_vars = [x[idx] for x in grads]
-                train_ops.append(opt.apply_gradients(
-                    grad_and_vars, name='apply_grad_{}'.format(idx)))
+                # apply_gradients may create variables. Make them LOCAL_VARIABLES
+                with override_to_local_variable(enable=idx > 0):
+                    train_ops.append(opt.apply_gradients(
+                        grad_and_vars, name='apply_grad_{}'.format(idx)))
         train_op = tf.group(*train_ops, name='train_op')
```
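The comment in the hunk above is the heart of the commit: `tf.train.Saver()` constructed without a `var_list` defaults to saveable objects plus the GLOBAL_VARIABLES collection, so anything rerouted into LOCAL_VARIABLES stays out of checkpoints. A minimal illustration (variable names hypothetical, not from the diff):

```python
import tensorflow as tf

w = tf.get_variable('w', shape=[2])              # tower-0 master copy: saved
w_copy = tf.get_variable('tower1/w', shape=[2],  # duplicated tower copy
                         collections=[tf.GraphKeys.LOCAL_VARIABLES])

saver = tf.train.Saver()  # var_list defaults to global (+ saveable) variables
print([v.op.name for v in tf.global_variables()])  # ['w'] -- copy excluded
```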
tensorpack/train/utility.py (new file, mode 100644)

```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: utility.py

import tensorflow as tf
from contextlib import contextmanager
import operator


@contextmanager
def override_to_local_variable(enable=True):
    if enable:
        with tf.variable_scope(
                tf.get_variable_scope(),
                custom_getter=OverrideToLocalVariable()):
            yield
    else:
        yield


class OverrideToLocalVariable(object):
    """
    Ensures the created variable
    is in LOCAL_VARIABLES and not GLOBAL_VARIBLES collection.
    """
    def __call__(self, getter, name, *args, **kwargs):
        if 'collections' in kwargs:
            collections = kwargs['collections']
        if not collections:
            collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
        else:
            collections = set(collections.copy())
        collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
        collections.add(tf.GraphKeys.LOCAL_VARIABLES)
        kwargs['collections'] = list(collections)
        return getter(name, *args, **kwargs)


# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
class LeastLoadedDeviceSetter(object):
    """ Helper class to assign variables on the least loaded ps-device."""
    def __init__(self, worker_device, ps_devices):
        """
        Args:
            worker_device: the device to use for compute ops.
            ps_devices: a list of device to use for Variable ops.
        """
        self.ps_devices = ps_devices
        self.worker_device = worker_device
        self.ps_sizes = [0] * len(self.ps_devices)

    def __call__(self, op):
        def sanitize_name(name):
            # tensorflow/tensorflow#11484
            return tf.DeviceSpec.from_string(name).to_string()

        if op.device:
            return op.device
        if op.type not in ['Variable', 'VariableV2']:
            return sanitize_name(self.worker_device)

        device_index, _ = min(enumerate(self.ps_sizes),
                              key=operator.itemgetter(1))
        device_name = self.ps_devices[device_index]
        var_size = op.outputs[0].get_shape().num_elements()
        self.ps_sizes[device_index] += var_size

        return sanitize_name(device_name)
```
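LeastLoadedDeviceSetter is used as a device function: passed to `tf.device()`, it greedily places each Variable op on the ps device with the smallest accumulated parameter count, and every other op on the worker device. A hedged sketch matching how the parameter-server trainer uses it (device strings here are examples):

```python
import tensorflow as tf
from tensorpack.train.utility import LeastLoadedDeviceSetter

setter = LeastLoadedDeviceSetter(worker_device='/gpu:0',
                                 ps_devices=['/gpu:0', '/gpu:1'])

with tf.device(setter):
    a = tf.get_variable('a', shape=[1024])   # placed on the emptier ps device
    b = tf.get_variable('b', shape=[8])      # balanced onto the other device
    y = tf.reduce_sum(a) + tf.reduce_sum(b)  # compute ops -> worker_device
```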