Commit 1a348f00 ("can run")
authored Apr 21, 2017 by Yuxin Wu
parent a3674b47
Showing 3 changed files, with 421 additions and 2 deletions:

examples/distributed.py          +149   -0
tensorpack/tfutils/summary.py      +3   -2
tensorpack/train/distributed.py  +269   -0
examples/distributed.py  (new file, mode 100755)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: mnist-dist.py

import numpy as np
import os
import sys
import argparse

"""
MNIST ConvNet example.
about 0.6% validation error after 30 epochs.
"""

# Just import everything into current namespace
from tensorpack import *
import tensorflow as tf
import tensorpack.tfutils.symbolic_functions as symbf

IMAGE_SIZE = 28


class Model(ModelDesc):

    def _get_inputs(self):
        """
        Define all the inputs (with type, shape, name) that
        the graph will need.
        """
        return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
                InputDesc(tf.int32, (None,), 'label')]

    def _build_graph(self, inputs):
        """This function should build the model which takes the input variables
        and define self.cost at the end"""

        # inputs contains a list of input variables defined above
        image, label = inputs

        # In tensorflow, inputs to convolution function are assumed to be
        # NHWC. Add a single channel here.
        image = tf.expand_dims(image, 3)

        image = image * 2 - 1   # center the pixel values at zero

        # The context manager `argscope` sets the default option for all the layers under
        # this context. Here we use 32 channel convolution with shape 3x3
        with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32):
            logits = (LinearWrap(image)
                      .Conv2D('conv0')
                      .MaxPooling('pool0', 2)
                      .Conv2D('conv1')
                      .Conv2D('conv2')
                      .MaxPooling('pool1', 2)
                      .Conv2D('conv3')
                      .FullyConnected('fc0', 512, nl=tf.nn.relu)
                      .Dropout('dropout', 0.5)
                      .FullyConnected('fc1', out_dim=10, nl=tf.identity)())
        prob = tf.nn.softmax(logits, name='prob')  # a Bx10 tensor with probabilities

        # a vector of length B with loss of each sample
        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
        cost = tf.reduce_mean(cost, name='cross_entropy_loss')  # the average cross-entropy loss

        # compute the "incorrect vector", for the callback ClassificationError to use at validation time
        wrong = symbf.prediction_incorrect(logits, label, name='incorrect')
        accuracy = symbf.accuracy(logits, label, name='accuracy')

        # This will monitor training error (in a moving_average fashion):
        # 1. write the value to tensorboard
        # 2. write the value to stat.json
        # 3. print the value after each epoch
        train_error = tf.reduce_mean(wrong, name='train_error')
        summary.add_moving_summary(train_error, accuracy)

        # Use a regex to find parameters to apply weight decay.
        # Here we apply a weight decay on all W (weight matrix) of all fc layers
        wd_cost = tf.multiply(1e-5, regularize_cost('fc.*/W', tf.nn.l2_loss),
                              name='regularize_loss')
        self.cost = tf.add_n([wd_cost, cost], name='total_cost')
        summary.add_moving_summary(cost, wd_cost, self.cost)

        # monitor histogram of all weights (of conv and fc layers) in tensorboard
        summary.add_param_summary(('.*/W', ['histogram', 'rms']))

    def _get_optimizer(self):
        lr = tf.train.exponential_decay(
            learning_rate=1e-3,
            global_step=get_global_step_var(),
            decay_steps=468 * 10,
            decay_rate=0.3, staircase=True, name='learning_rate')
        # This will also put the summary in tensorboard, stat.json and print in terminal,
        # but this time without moving average
        tf.summary.scalar('lr', lr)
        return tf.train.AdamOptimizer(lr)


def get_data():
    train = BatchData(dataset.Mnist('train'), 128)
    test = BatchData(dataset.Mnist('test'), 256, remainder=True)
    return train, test


def get_config():
    # automatically setup the directory train_log/mnist-convnet for logging
    logger.auto_set_dir('k')

    dataset_train, dataset_test = get_data()

    # How many iterations you want in each epoch.
    # This is the default value, don't actually need to set it in the config
    steps_per_epoch = dataset_train.size()

    # get the config which contains everything necessary in a training
    return TrainConfig(
        model=Model(),
        dataflow=dataset_train,   # the DataFlow instance for training
        callbacks=[
            # ModelSaver(),                      # save the model after every epoch
            # MaxSaver('validation_accuracy'),   # save the model with highest accuracy (prefix 'validation_')
            # InferenceRunner(                   # run inference (for validation) after every epoch
            #     dataset_test,                  # the DataFlow instance used for validation
            #     # Calculate both the cost and the error for this DataFlow
            #     [ScalarStats('cross_entropy_loss'), ScalarStats('accuracy'),
            #      ClassificationError('incorrect')]),
        ],
        steps_per_epoch=steps_per_epoch,
        max_epoch=100,
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--load', help='load model')
    parser.add_argument('--job', required=True)
    parser.add_argument('--task', type=int, default=0)
    args = parser.parse_args()
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    config = get_config()
    if args.load:
        config.session_init = SaverRestore(args.load)

    cluster_spec = tf.train.ClusterSpec({
        'ps': ['0.0.0.0:2222'],
        'worker': ['0.0.0.0:2223', '0.0.0.0:2224']
    })
    config.data = QueueInput(config.dataflow)
    DistributedReplicatedTrainer(config, args.job, args.task, cluster_spec).train()
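
The ClusterSpec above hard-codes one parameter server (port 2222) and two workers (ports 2223 and 2224), so a complete run needs three processes, one per (job, task) pair. The snippet below is a minimal local-launch sketch and is not part of the commit; the script path and the per-worker GPU assignment are assumptions for illustration.

import subprocess

# Launch the ps and the two workers declared in the ClusterSpec; each process
# picks its role from --job/--task and joins the same cluster.
procs = [
    subprocess.Popen(['python', 'examples/distributed.py', '--job', 'ps', '--task', '0']),
    subprocess.Popen(['python', 'examples/distributed.py', '--job', 'worker', '--task', '0', '--gpu', '0']),
    subprocess.Popen(['python', 'examples/distributed.py', '--job', 'worker', '--task', '1', '--gpu', '1']),
]
for p in procs:
    p.wait()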
tensorpack/tfutils/summary.py

@@ -156,9 +156,10 @@ def add_moving_summary(v, *args, **kwargs):
         assert x.get_shape().ndims == 0, x.get_shape()
     # TODO will produce tower0/xxx?
     # TODO use zero_debias
-    with tf.name_scope(None):
+    gs = get_global_step_var()
+    with tf.name_scope(None), tf.device(gs.device):
         averager = tf.train.ExponentialMovingAverage(
-            decay, num_updates=get_global_step_var(), name='EMA')
+            decay, num_updates=gs, name='EMA')
         avg_maintain_op = averager.apply(v)
     for c in v:
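
The change caches the global-step variable in gs and, more importantly, wraps the EMA construction in tf.device(gs.device), so the moving-average bookkeeping is colocated with the global step instead of whatever device happens to be current (which matters once variables live on a parameter server). A standalone sketch of that pattern, assuming TF 1.x graph mode and using purely illustrative names:

import tensorflow as tf

# Colocate the EMA update with the global-step variable by reusing its device.
gs = tf.train.get_or_create_global_step()
metric = tf.get_variable('metric', shape=[], initializer=tf.zeros_initializer())

with tf.name_scope(None), tf.device(gs.device):
    averager = tf.train.ExponentialMovingAverage(0.95, num_updates=gs, name='EMA')
    maintain_op = averager.apply([metric])  # run maintain_op each step to update the average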
tensorpack/train/distributed.py  (new file, mode 100644)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: distributed.py

import tensorflow as tf
from six.moves import range
import weakref
from tensorflow.python.training.monitored_session \
    import _HookedSession as HookedSession

from ..utils import logger
from .input_source import StagingInputWrapper, FeedfreeInput
from .feedfree import SingleCostFeedfreeTrainer
from .multigpu import MultiGPUTrainerBase
from ..tfutils.model_utils import describe_model
from ..callbacks import Callbacks, ProgressBar
from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.common import get_default_sess_config, get_global_step_var
from ..callbacks.monitor import Monitors

__all__ = ['DistributedReplicatedTrainer']

PS_SHADOW_VAR_PREFIX = 'ps_var'


# To be used with custom_getter on tf.get_variable. Ensures the created variable
# is in the LOCAL_VARIABLES and not the GLOBAL_VARIABLES collection.
class OverrideToLocalVariableIfNotPsVar(object):
    # args and kwargs come from the custom_getter interface for TensorFlow
    # variables, and match tf.get_variable's signature, with the addition of
    # 'getter' at the beginning.
    def __call__(self, getter, name, *args, **kwargs):
        if name.startswith(PS_SHADOW_VAR_PREFIX):
            return getter(*args, **kwargs)

        logger.info("CustomGetter-{}".format(name))
        if 'collections' in kwargs:
            collections = kwargs['collections']
        if not collections:
            collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
        else:
            collections = set(collections.copy())
        collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
        collections.add(tf.GraphKeys.LOCAL_VARIABLES)
        kwargs['collections'] = list(collections)
        return getter(name, *args, **kwargs)


class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
    def __init__(self, config, job_name, task_index, cluster):
        assert job_name in ['ps', 'worker'], job_name
        self.config = config
        self.job_name = job_name
        self.task_index = task_index
        self.cluster = cluster
        self._input_source = config.data
        super(DistributedReplicatedTrainer, self).__init__(config)

        worker_prefix = '/job:worker/task:%s' % self.task_index
        self.param_server_device = tf.train.replica_device_setter(
            worker_device=worker_prefix + '/cpu:0', cluster=self.cluster)
        # The device on which the queues for managing synchronization between
        # servers should be stored.
        num_ps = self.cluster.num_tasks('ps')
        self.cpu_device = '%s/cpu:0' % worker_prefix
        self.nr_gpu = config.nr_tower
        self.raw_devices = ['%s/%s:%i' % (worker_prefix, 'gpu', i)
                            for i in range(self.nr_gpu)]

        self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(num_ps)]
        self.sync_queue_counter = 0

        if self.nr_gpu > 1:
            assert tf.test.is_gpu_available()

            # seems to only improve on >1 GPUs
            if not isinstance(self._input_source, StagingInputWrapper):
                self._input_source = StagingInputWrapper(self._input_source, self.raw_devices)

    def _setup(self):
        conf = get_default_sess_config()
        self.server = tf.train.Server(
            self.cluster, job_name=self.job_name,
            task_index=self.task_index,
            config=conf)  # TODO sessconfig
        if self.job_name == 'ps':
            logger.info("Running ps {}".format(self.task_index))
            self.server.join()
            return

        with tf.variable_scope(
                tf.get_variable_scope(),
                custom_getter=OverrideToLocalVariableIfNotPsVar()):
            # Ngpu * Nvar * 2
            grad_list = MultiGPUTrainerBase.build_on_multi_tower(
                self.config.tower,
                lambda: self._get_cost_and_grad()[1],
                devices=self.raw_devices,
                var_strategy='replicated')

        # (g, v) to be applied, where v is global (ps vars)
        new_tower_grads = []
        for i, grad_and_vars in enumerate(zip(*grad_list)):  # Ngpu * 2
            with tf.device(self.raw_devices[i % self.nr_gpu]):
                v = grad_and_vars[0][1]
                if self.nr_gpu > 1:
                    # average gradient
                    all_grads = [g for (g, _) in grad_and_vars]
                    if not MultiGPUTrainerBase.check_none_grads(v.op.name, all_grads):
                        continue
                    grad = tf.multiply(tf.add_n(all_grads), 1.0 / self.nr_gpu)
                else:
                    grad = grad_and_vars[0][0]
            with tf.device(self.param_server_device):
                my_name = PS_SHADOW_VAR_PREFIX + '/' + v.name
                if my_name.endswith(':0'):
                    my_name = my_name[:-2]
                new_v = tf.get_variable(my_name, dtype=v.dtype.base_dtype,
                                        initializer=v.initial_value,
                                        trainable=True)
                new_tower_grads.append((grad, new_v))

        # apply gradients TODO do this for each variable separately?
        opt = self.model.get_optimizer()
        apply_gradient_op = opt.apply_gradients(new_tower_grads)
        barrier = self.add_sync_queues_and_barrier(
            'replicate_variable', [apply_gradient_op])

        var_update_ops = []
        with tf.control_dependencies([barrier]), \
                tf.device(self.cpu_device):
            for idx, (grad, v) in enumerate(new_tower_grads):
                updated_value = v.read_value()
                for towerid in range(self.nr_gpu):
                    logger.info("Step update {} -> {}".format(
                        v.name, grad_list[towerid][idx][1].name))
                    var_update_ops.append(
                        grad_list[towerid][idx][1].assign(updated_value))

        self.main_fetch = tf.group(*var_update_ops, name='main_fetches')
        self.train_op = self.add_sync_queues_and_barrier(
            'sync_queues_step_end', [self.main_fetch])
        self.post_init_op = self.get_post_init_ops()

    def setup(self):
        with tf.device(self.param_server_device):
            gs = get_global_step_var()
        self.is_chief = (self.task_index == 0 and self.job_name == 'worker')

        assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
        self._input_source.setup_training(self)
        self._setup()

        self.monitors = Monitors(self.monitors)
        self.register_callback(self.monitors)

        describe_model()
        # some final operations that might modify the graph
        logger.info("Setup callbacks graph ...")
        # if not self.is_chief:
        #     self._callbacks = [ProgressBar()]
        self._callbacks = Callbacks(self._callbacks)
        self._callbacks.setup_graph(weakref.proxy(self))

        # local_init_op = tf.local_variables_initializer()
        global_init_op = tf.global_variables_initializer()

        logger.info("Finalize the graph, create the session ...")

        self.sv = tf.train.Supervisor(
            is_chief=self.is_chief,
            logdir=None, saver=None,
            global_step=gs,
            summary_op=None,
            save_model_secs=0,
            # local_init_op=local_init_op,
            # ready_for_local_init_op=None,
            summary_writer=None)
        conf = get_default_sess_config()
        sess = self.sv.prepare_or_wait_for_session(
            master=self.server.target,
            config=conf, start_standard_services=False)
        self.sess = sess
        if self.is_chief:
            print([k.name for k in tf.global_variables()])
            sess.run(global_init_op)
            logger.info("Global variables initialized.")
        # sess.run(local_init_op)
        # if self.is_chief:
        #     self.config.session_init.init(self.sess)
        #     self.sess.graph.finalize()
        # else:
        #     logger.info("Worker {} waiting for chief".format(self.task_index))
        #     self.sess = tf.train.WorkerSessionCreator(master=self.server.target).create_session()
        #     logger.info("Worker wait finished")
        #     self.sess.run(local_init_op)
        #     logger.info("local init op runned")
        logger.info("Running post init op...")
        sess.run(self.post_init_op)
        logger.info("Post init op finished.")

        self._monitored_sess = tf.train.MonitoredSession(
            session_creator=ReuseSessionCreator(self.sess), hooks=None)
        # self._monitored_sess = self.sv

        hooks = self._callbacks.get_hooks()
        self.hooked_sess = HookedSession(self.sess, hooks)

    def add_sync_queues_and_barrier(self, name_prefix, enqueue_after_list):
        """Adds ops to enqueue on all worker queues.

        Args:
            name_prefix: prefix for the shared_name of ops.
            enqueue_after_list: control dependency from ops.

        Returns:
            an op that should be used as control dependency before starting next step.
        """
        self.sync_queue_counter += 1
        num_workers = self.cluster.num_tasks('worker')
        with tf.device(self.sync_queue_devices[self.sync_queue_counter % len(self.sync_queue_devices)]):
            sync_queues = [
                tf.FIFOQueue(num_workers, [tf.bool], shapes=[[]],
                             shared_name='%s%s' % (name_prefix, i))
                for i in range(num_workers)]
            queue_ops = []
            # For each other worker, add an entry in a queue, signaling that it can
            # finish this step.
            token = tf.constant(False)
            with tf.control_dependencies(enqueue_after_list):
                for i, q in enumerate(sync_queues):
                    if i == self.task_index:
                        queue_ops.append(tf.no_op())
                    else:
                        queue_ops.append(q.enqueue(token))

            # Drain tokens off queue for this worker, one for each other worker.
            queue_ops.append(
                sync_queues[self.task_index].dequeue_many(len(sync_queues) - 1))

            return tf.group(*queue_ops)

    def get_post_init_ops(self):
        # Copy initialized variables for variables on the parameter server
        # to the local copy of the variable.
        def strip_port(s):
            if s.endswith(':0'):
                return s[:-2]
            return s

        local_vars = tf.local_variables()
        local_var_by_name = dict([(strip_port(v.name), v) for v in local_vars])
        post_init_ops = []
        for v in tf.global_variables():
            if v.name.startswith(PS_SHADOW_VAR_PREFIX + '/'):
                prefix = strip_port(v.name[len(PS_SHADOW_VAR_PREFIX + '/'):])
                for i in range(self.nr_gpu):
                    if i == 0:
                        name = prefix
                    else:
                        name = 'tower%s/%s' % (i, prefix)
                    if name in local_var_by_name:
                        copy_to = local_var_by_name[name]
                        logger.info("Post Init {} -> {}".format(v.name, copy_to.name))
                        post_init_ops.append(copy_to.assign(v.read_value()))
                    else:
                        logger.warn("Global var {} doesn't match local var".format(v.name))
        return tf.group(*post_init_ops, name='post_init_ops')
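
The OverrideToLocalVariableIfNotPsVar getter is the piece that keeps the per-tower replicas out of GLOBAL_VARIABLES, so that only the ps_var shadow copies are treated as global state. Below is a tiny, hypothetical demonstration of that behaviour; it is not part of the commit and assumes the class defined above is already in scope in a fresh TF 1.x graph.

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    # Any variable created under the custom getter is routed into LOCAL_VARIABLES,
    # unless its name carries the PS_SHADOW_VAR_PREFIX.
    with tf.variable_scope(tf.get_variable_scope(),
                           custom_getter=OverrideToLocalVariableIfNotPsVar()):
        w = tf.get_variable('tower0/w', shape=[3], initializer=tf.zeros_initializer())

    print([v.name for v in tf.local_variables()])   # expected to include 'tower0/w:0'
    print([v.name for v in tf.global_variables()])  # expected not to include it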