seminar-breakout / Commits

Commit 6cb47609, authored Oct 17, 2017 by Yuxin Wu
Parent: ce709fa3

    fix distributed training, GAN examples. rename utilities

Showing 10 changed files with 28 additions and 212 deletions (+28 -212)
docs/conf.py                                      +1   -3
examples/GAN/GAN.py                               +2   -4
tensorpack/graph_builder/_utils.py                +0   -124
tensorpack/graph_builder/distributed.py           +2   -8
tensorpack/graph_builder/input_source_base.py     +1   -1
tensorpack/graph_builder/training.py              +1   -1
tensorpack/train/__init__.py                      +1   -1
tensorpack/train/distributed.py                   +15  -6
tensorpack/train/multigpu.py                      +1   -1
tensorpack/train/utility.py                       +4   -63
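Taken together, the commit moves the device-placement and variable-scope helpers out of the private module tensorpack/graph_builder/_utils.py and out of tensorpack/train/utility.py into a public tensorpack/graph_builder/utils.py, keeps the old train.utility path alive through a re-export shim, and fixes two bugs (a missing return in MultiGPUTrainerBase.build_on_multi_tower and chief/ps handling in distributed training). A sketch of the import migration implied by the diffs below; the package-level form is taken from the updated GAN example, the rest is inferred from the renamed files:

    # Old (pre-commit) locations, now gone or deprecated:
    #   from tensorpack.graph_builder._utils import LeastLoadedDeviceSetter
    #   from tensorpack.train.utility import LeastLoadedDeviceSetter

    # New canonical location after this commit:
    from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter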
docs/conf.py

@@ -353,9 +353,7 @@ def process_signature(app, what, name, obj, options, signature,

 def autodoc_skip_member(app, what, name, obj, skip, options):
     if name in [
         'SingleCostFeedfreeTrainer',
         'SimpleFeedfreeTrainer',
         'FeedfreeTrainerBase',
         'MultiGPUTrainerBase',
         'FeedfreeInferenceRunner',
         'replace_get_variable',
         'remap_get_variable',
examples/GAN/GAN.py

@@ -8,9 +8,8 @@ import numpy as np
 import time
 from tensorpack import (
     Trainer, QueueInput, ModelDescBase, DataFlow,
     StagingInputWrapper, MultiGPUTrainerBase, TowerContext)
-from tensorpack.train.utility import LeastLoadedDeviceSetter
+from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter
 from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.utils.argtools import memoized

@@ -146,8 +145,7 @@ class MultiGPUGANTrainer(Trainer):
             model.build_graph(input)
             return [model.d_loss, model.g_loss]

         devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
-        cost_list = MultiGPUTrainerBase.build_on_multi_tower(
-            config.tower, get_cost, devices)
+        cost_list = DataParallelBuilder.build_on_towers(
+            config.tower, get_cost, devices)
         # simply average the cost. It might get faster to average the gradients
         with tf.name_scope('optimize'):
             d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu)
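The hunk above averages the per-tower costs rather than the per-tower gradients, as its comment notes. A self-contained sketch of that reduction; nr_gpu and the [d_loss, g_loss] pairs mirror the example, the constant values are made up:

    import tensorflow as tf

    nr_gpu = 2
    # one [d_loss, g_loss] pair per tower, as returned by get_cost on each GPU
    cost_list = [[tf.constant(1.0), tf.constant(2.0)],
                 [tf.constant(3.0), tf.constant(4.0)]]
    d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu)  # mean discriminator loss
    g_loss = tf.add_n([x[1] for x in cost_list]) * (1.0 / nr_gpu)  # mean generator loss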
tensorpack/graph_builder/_utils.py  deleted  100644 → 0

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: _utils.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import copy
from six.moves import zip
from contextlib import contextmanager
import operator
import tensorflow as tf

from ..tfutils.common import get_op_tensor_name
from ..utils import logger

__all__ = ['get_tensors_inputs', 'get_sublist_by_names']


def get_tensors_inputs(placeholders, tensors, names):
    """
    Args:
        placeholders (list[Tensor]):
        tensors (list[Tensor]): list of tf.Tensor
        names (list[str]): names matching the tensors
    Returns:
        list[Tensor]: inputs to be used with build_graph(),
            with the corresponding placeholders replaced by tensors.
    """
    assert len(tensors) == len(names), \
        "Input tensors {} and input names {} have different length!".format(
            tensors, names)
    ret = copy.copy(placeholders)
    placeholder_names = [p.name for p in placeholders]
    for name, tensor in zip(names, tensors):
        tensorname = get_op_tensor_name(name)[1]
        try:
            idx = placeholder_names.index(tensorname)
        except ValueError:
            logger.error("Name {} is not a model input!".format(tensorname))
            raise
        ret[idx] = tensor
    return ret


def get_sublist_by_names(lst, names):
    """
    Args:
        lst (list): list of objects with a "name" property.
        names (list[str]): names to select.
    Returns:
        list: the sublist of objects whose names match.
    """
    orig_names = [p.name for p in lst]
    ret = []
    for name in names:
        try:
            idx = orig_names.index(name)
        except ValueError:
            logger.error("Name {} doesn't appear in lst {}!".format(
                name, str(orig_names)))
            raise
        ret.append(lst[idx])
    return ret


@contextmanager
def override_to_local_variable(enable=True):
    if enable:
        with tf.variable_scope(
                tf.get_variable_scope(),
                custom_getter=OverrideToLocalVariable()):
            yield
    else:
        yield


class OverrideToLocalVariable(object):
    """
    Ensures the created variable
    is in LOCAL_VARIABLES and not in the GLOBAL_VARIABLES collection.
    """
    def __call__(self, getter, name, *args, **kwargs):
        if 'collections' in kwargs:
            collections = kwargs['collections']
        if not collections:
            collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
        else:
            collections = set(collections.copy())
        collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
        collections.add(tf.GraphKeys.LOCAL_VARIABLES)
        kwargs['collections'] = list(collections)
        return getter(name, *args, **kwargs)


# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
class LeastLoadedDeviceSetter(object):
    """ Helper class to assign variables on the least loaded ps-device."""
    def __init__(self, worker_device, ps_devices):
        """
        Args:
            worker_device: the device to use for compute ops.
            ps_devices: a list of devices to use for Variable ops.
        """
        self.ps_devices = ps_devices
        self.worker_device = worker_device
        self.ps_sizes = [0] * len(self.ps_devices)

    def __call__(self, op):
        def sanitize_name(name):    # tensorflow/tensorflow#11484
            return tf.DeviceSpec.from_string(name).to_string()

        if op.device:
            return op.device
        if op.type not in ['Variable', 'VariableV2']:
            return sanitize_name(self.worker_device)

        device_index, _ = min(enumerate(self.ps_sizes),
                              key=operator.itemgetter(1))
        device_name = self.ps_devices[device_index]
        var_size = op.outputs[0].get_shape().num_elements()
        self.ps_sizes[device_index] += var_size
        return sanitize_name(device_name)
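Both helpers survive the commit unchanged; they only move to tensorpack/graph_builder/utils.py. For context, a hedged sketch of how they are meant to be used (device names are made up; LeastLoadedDeviceSetter is passed to tf.device as a device function, exactly as the GAN example above does):

    import tensorflow as tf
    from tensorpack.graph_builder.utils import (
        LeastLoadedDeviceSetter, override_to_local_variable)

    ps_devices = ['/gpu:0', '/gpu:1']    # hypothetical parameter-server devices
    setter = LeastLoadedDeviceSetter('/gpu:0', ps_devices)
    with tf.device(setter):
        # Variables land on whichever ps device currently holds the fewest
        # parameters; all other ops stay on the worker device.
        w = tf.get_variable('w', shape=[1024, 1024])

    with override_to_local_variable():
        # variables created here go to LOCAL_VARIABLES instead of GLOBAL_VARIABLES
        counter = tf.get_variable('counter', shape=[], dtype=tf.int32)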
tensorpack/graph_builder/distributed.py

@@ -6,7 +6,6 @@ import tensorflow as tf
 import re
 from six.moves import zip, range

 from ..utils import logger
 from ..utils.argtools import memoized
 from ..tfutils.gradproc import FilterNoneGrad
 from ..tfutils.common import get_global_step_var, get_op_tensor_name

@@ -24,14 +23,9 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
         self.server = server
         server_def = server.server_def
         self.cluster = tf.train.ClusterSpec(server_def.cluster)
         self.job_name = server_def.job_name
         self.task_index = server_def.task_index
-        # TODO XXX ps doesn't need to build!
-        assert self.job_name in ['ps', 'worker'], self.job_name
-        logger.info("Distributed training on cluster:\n" + str(server_def.cluster))
-        logger.info("My role in the cluster: job={}, task={}".format(
-            self.job_name, self.task_index))
-        self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
+        self.is_chief = (self.task_index == 0)

         worker_prefix = '/job:worker/task:%s' % self.task_index
         self.param_server_device = tf.train.replica_device_setter(

@@ -152,10 +146,10 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
         return tf.group(*queue_ops, name=name)

     def build(self, input, get_cost_fn, get_opt_fn):
-        # do this before everything, because they may need the global step
-        with tf.device(self.param_server_device):
-            gs = get_global_step_var()
-            assert gs.device, gs.device
-        # do this before inputsource.setup because input_source may need the global step
         get_opt_fn = memoized(get_opt_fn)
+        # Build the optimizer first, before entering any tower.
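For context, the builder derives job_name, task_index and is_chief from a tf.train.Server. A hedged sketch of the TF 1.x cluster setup such a builder expects (host names and ports are made up):

    import tensorflow as tf

    # two-worker, one-ps cluster; addresses are hypothetical
    cluster = tf.train.ClusterSpec({
        'ps':     ['host0:2222'],
        'worker': ['host1:2222', 'host2:2222'],
    })
    server = tf.train.Server(cluster, job_name='worker', task_index=0)

    # the fields the builder reads off the server:
    server_def = server.server_def
    assert server_def.job_name == 'worker' and server_def.task_index == 0
    # After this commit the builder alone decides chief == task 0 of its job;
    # the worker-vs-ps distinction moves up into the trainer (see train/distributed.py below).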
tensorpack/graph_builder/input_source_base.py

@@ -8,7 +8,7 @@ from contextlib import contextmanager
 import tensorflow as tf

 from ..utils.argtools import memoized
-from ._utils import get_sublist_by_names, get_tensors_inputs
+from .utils import get_sublist_by_names, get_tensors_inputs
 from ..callbacks.base import CallbackFactory

 __all__ = ['InputSource', 'remap_input_source']
tensorpack/graph_builder/training.py

@@ -15,7 +15,7 @@ from ..tfutils.common import get_tf_version_number
 from ..tfutils.collection import backup_collection, restore_collection
 from ..tfutils.gradproc import ScaleGradient
 from ..utils.naming import TOWER_FREEZE_KEYS
-from ._utils import LeastLoadedDeviceSetter, override_to_local_variable
+from .utils import LeastLoadedDeviceSetter, override_to_local_variable

 __all__ = ['GraphBuilder', 'SimpleBuilder',
tensorpack/train/__init__.py

@@ -19,7 +19,7 @@ def global_import(name):
 _CURR_DIR = os.path.dirname(__file__)
-_SKIP = []
+_SKIP = ['utility']
 for _, module_name, _ in iter_modules([_CURR_DIR]):
     srcpath = os.path.join(_CURR_DIR, module_name + '.py')
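The hunk above adds 'utility' to the skip list of tensorpack.train's auto-import loop, so the deprecated module is no longer pulled into the package namespace. A hedged sketch of that loop; the filtering conditions beyond what the hunk shows are assumptions about the surrounding file:

    import os
    from pkgutil import iter_modules

    _CURR_DIR = os.path.dirname(__file__)
    _SKIP = ['utility']
    for _, module_name, _ in iter_modules([_CURR_DIR]):
        srcpath = os.path.join(_CURR_DIR, module_name + '.py')
        if not os.path.isfile(srcpath):
            continue                  # skip packages and compiled modules (assumed)
        if module_name.startswith('_') or module_name in _SKIP:
            continue                  # private modules and the new skip list (assumed)
        global_import(module_name)    # helper defined earlier in the file, per the hunk header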
tensorpack/train/distributed.py

@@ -11,7 +11,7 @@ from ..tfutils.sesscreate import NewSessionCreator
 from ..tfutils.common import get_global_step_var
 from ..graph_builder.distributed import DistributedReplicatedBuilder
-from .utility import override_to_local_variable
+from ..graph_builder.utils import override_to_local_variable
 from .base import Trainer

@@ -63,25 +63,34 @@ class DistributedTrainerReplicated(Trainer):
         assert config.data is not None and config.model is not None

         self.server = server
+        self.job_name = server.server_def.job_name
+        assert self.job_name in ['ps', 'worker'], self.job_name
+
+        if self.job_name == 'worker':
+            # ps doesn't build any graph
+            self._builder = DistributedReplicatedBuilder(config.tower, server)
+            self.is_chief = self._builder.is_chief
+        else:
+            self.is_chief = False
+        logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))

         self._input_source = config.data
-        self.is_chief = self._builder.is_chief
         self.nr_gpu = config.nr_tower

         super(DistributedTrainerReplicated, self).__init__(config)

     def _setup(self):
-        if self._builder.job_name == 'ps':
+        if self.job_name == 'ps':
             logger.info("Running ps {}".format(self._builder.task_index))
             logger.info("Kill me with 'kill {}'".format(os.getpid()))
             self.server.join()  # this will never return  tensorflow#4713
             return

+        # TODO Can we just do this in get_global_step_var
+        with tf.device(self._builder.param_server_device):
+            gs = get_global_step_var()
+            assert gs.device, gs.device
         # always do this before inputsource.setup because input_source may need the global step
         with override_to_local_variable():
             # input source may create variable (queue size summary)
tensorpack/train/multigpu.py

@@ -27,7 +27,7 @@ class MultiGPUTrainerBase(Trainer):
     For backward compatibility only
     """
     def build_on_multi_tower(towers, func, devices=None, use_vs=None):
-        DataParallelBuilder.build_on_towers(towers, func, devices, use_vs)
+        return DataParallelBuilder.build_on_towers(towers, func, devices, use_vs)


 def apply_prefetch_policy(config, gpu_prefetch=True):
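This one-line change is part of the GAN fix named in the commit message: the backward-compatibility wrapper forwarded the call but dropped its result, so callers such as MultiGPUGANTrainer received None instead of the per-tower cost list. A minimal self-contained illustration of the failure mode (all names here are hypothetical):

    def build(towers):
        return [t * 2 for t in towers]   # stands in for DataParallelBuilder.build_on_towers

    def wrapper_broken(towers):
        build(towers)                    # result silently discarded

    def wrapper_fixed(towers):
        return build(towers)

    assert wrapper_broken([1, 2]) is None    # the pre-commit behavior
    assert wrapper_fixed([1, 2]) == [2, 4]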
tensorpack/train/utility.py

@@ -2,66 +2,7 @@
 # -*- coding: utf-8 -*-
 # File: utility.py

-import tensorflow as tf
-from contextlib import contextmanager
-import operator
-
-
-@contextmanager
-def override_to_local_variable(enable=True):
-    if enable:
-        with tf.variable_scope(
-                tf.get_variable_scope(),
-                custom_getter=OverrideToLocalVariable()):
-            yield
-    else:
-        yield
-
-
-class OverrideToLocalVariable(object):
-    """
-    Ensures the created variable
-    is in LOCAL_VARIABLES and not in the GLOBAL_VARIABLES collection.
-    """
-    def __call__(self, getter, name, *args, **kwargs):
-        if 'collections' in kwargs:
-            collections = kwargs['collections']
-        if not collections:
-            collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
-        else:
-            collections = set(collections.copy())
-        collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
-        collections.add(tf.GraphKeys.LOCAL_VARIABLES)
-        kwargs['collections'] = list(collections)
-        return getter(name, *args, **kwargs)
-
-
-# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
-class LeastLoadedDeviceSetter(object):
-    """ Helper class to assign variables on the least loaded ps-device."""
-    def __init__(self, worker_device, ps_devices):
-        """
-        Args:
-            worker_device: the device to use for compute ops.
-            ps_devices: a list of devices to use for Variable ops.
-        """
-        self.ps_devices = ps_devices
-        self.worker_device = worker_device
-        self.ps_sizes = [0] * len(self.ps_devices)
-
-    def __call__(self, op):
-        def sanitize_name(name):    # tensorflow/tensorflow#11484
-            return tf.DeviceSpec.from_string(name).to_string()
-
-        if op.device:
-            return op.device
-        if op.type not in ['Variable', 'VariableV2']:
-            return sanitize_name(self.worker_device)
-
-        device_index, _ = min(enumerate(self.ps_sizes),
-                              key=operator.itemgetter(1))
-        device_name = self.ps_devices[device_index]
-        var_size = op.outputs[0].get_shape().num_elements()
-        self.ps_sizes[device_index] += var_size
-        return sanitize_name(device_name)
+
+# for backwards-compatibility
+from ..graph_builder.utils import (    # noqa
+    OverrideToLocalVariable,
+    override_to_local_variable, LeastLoadedDeviceSetter)
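With the shim in place, the old and new import paths resolve to the same objects. A quick hedged check one could run against this revision of tensorpack:

    from tensorpack.train.utility import LeastLoadedDeviceSetter as old_path
    from tensorpack.graph_builder.utils import LeastLoadedDeviceSetter as new_path
    assert old_path is new_path    # the re-export aliases the class, it does not copy it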