Shashank Suhas / seminar-breakout · Commits

Commit 1a348f00, authored Apr 21, 2017 by Yuxin Wu

    can run

Parent: a3674b47

Changes: 3 changed files, with 421 additions and 2 deletions (+421 −2)
examples/distributed.py            +149  −0
tensorpack/tfutils/summary.py        +3  −2
tensorpack/train/distributed.py    +269  −0
examples/distributed.py  (new file, mode 100755)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: mnist-dist.py

import numpy as np
import os
import sys
import argparse

"""
MNIST ConvNet example.
about 0.6% validation error after 30 epochs.
"""

# Just import everything into current namespace
from tensorpack import *
import tensorflow as tf
import tensorpack.tfutils.symbolic_functions as symbf

IMAGE_SIZE = 28


class Model(ModelDesc):
    def _get_inputs(self):
        """
        Define all the inputs (with type, shape, name) that
        the graph will need.
        """
        return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
                InputDesc(tf.int32, (None,), 'label')]

    def _build_graph(self, inputs):
        """This function should build the model which takes the input variables
        and defines self.cost at the end"""

        # inputs contains a list of input variables defined above
        image, label = inputs

        # In tensorflow, inputs to convolution functions are assumed to be
        # NHWC. Add a single channel here.
        image = tf.expand_dims(image, 3)

        image = image * 2 - 1   # center the pixel values at zero

        # The context manager `argscope` sets the default options for all the layers
        # under this context. Here we use 32-channel convolutions with 3x3 kernels.
        with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32):
            logits = (LinearWrap(image)
                      .Conv2D('conv0')
                      .MaxPooling('pool0', 2)
                      .Conv2D('conv1')
                      .Conv2D('conv2')
                      .MaxPooling('pool1', 2)
                      .Conv2D('conv3')
                      .FullyConnected('fc0', 512, nl=tf.nn.relu)
                      .Dropout('dropout', 0.5)
                      .FullyConnected('fc1', out_dim=10, nl=tf.identity)())
        prob = tf.nn.softmax(logits, name='prob')   # a Bx10 tensor of probabilities

        # a vector of length B with the loss of each sample
        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
        cost = tf.reduce_mean(cost, name='cross_entropy_loss')  # the average cross-entropy loss

        # compute the "incorrect vector", for the callback ClassificationError to use at validation time
        wrong = symbf.prediction_incorrect(logits, label, name='incorrect')
        accuracy = symbf.accuracy(logits, label, name='accuracy')

        # This will monitor training error (in a moving-average fashion):
        # 1. write the value to tensorboard
        # 2. write the value to stat.json
        # 3. print the value after each epoch
        train_error = tf.reduce_mean(wrong, name='train_error')
        summary.add_moving_summary(train_error, accuracy)

        # Use a regex to find parameters to apply weight decay to.
        # Here we apply a weight decay on all W (weight matrices) of all fc layers.
        wd_cost = tf.multiply(1e-5, regularize_cost('fc.*/W', tf.nn.l2_loss),
                              name='regularize_loss')
        self.cost = tf.add_n([wd_cost, cost], name='total_cost')
        summary.add_moving_summary(cost, wd_cost, self.cost)

        # monitor the histogram of all weights (of conv and fc layers) in tensorboard
        summary.add_param_summary(('.*/W', ['histogram', 'rms']))

    def _get_optimizer(self):
        lr = tf.train.exponential_decay(
            learning_rate=1e-3,
            global_step=get_global_step_var(),
            decay_steps=468 * 10,
            decay_rate=0.3, staircase=True, name='learning_rate')
        # This will also put the summary in tensorboard, stat.json and print it in
        # the terminal, but this time without moving average.
        tf.summary.scalar('lr', lr)
        return tf.train.AdamOptimizer(lr)
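        # Note on the schedule: with batch size 128, one MNIST epoch is
        # 60000/128 ≈ 468 steps, so decay_steps=468*10 decays the learning
        # rate by a factor of 0.3 roughly every 10 epochs.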

def get_data():
    train = BatchData(dataset.Mnist('train'), 128)
    test = BatchData(dataset.Mnist('test'), 256, remainder=True)
    return train, test


def get_config():
    # automatically set up the directory train_log/mnist-convnet for logging
    logger.auto_set_dir('k')

    dataset_train, dataset_test = get_data()

    # How many iterations you want in each epoch.
    # This is the default value; you don't actually need to set it in the config.
    steps_per_epoch = dataset_train.size()

    # get the config which contains everything necessary for a training
    return TrainConfig(
        model=Model(),
        dataflow=dataset_train,   # the DataFlow instance for training
        callbacks=[
            # ModelSaver(),   # save the model after every epoch
            # MaxSaver('validation_accuracy'),  # save the model with highest accuracy (prefix 'validation_')
            # InferenceRunner(    # run inference (for validation) after every epoch
            #     dataset_test,   # the DataFlow instance used for validation
            #     # Calculate both the cost and the error for this DataFlow
            #     [ScalarStats('cross_entropy_loss'), ScalarStats('accuracy'),
            #      ClassificationError('incorrect')]),
        ],
        steps_per_epoch=steps_per_epoch,
        max_epoch=100,
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--load', help='load model')
    parser.add_argument('--job', required=True)
    parser.add_argument('--task', type=int, default=0)
    args = parser.parse_args()
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    config = get_config()
    if args.load:
        config.session_init = SaverRestore(args.load)

    cluster_spec = tf.train.ClusterSpec({
        'ps': ['0.0.0.0:2222'],
        'worker': ['0.0.0.0:2223', '0.0.0.0:2224']
    })
    config.data = QueueInput(config.dataflow)
    DistributedReplicatedTrainer(config, args.job, args.task, cluster_spec).train()
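The ClusterSpec above hard-codes one parameter server (port 2222) and two workers (ports 2223 and 2224) on localhost, so three processes have to be started by hand. A plausible launch sequence, inferred from the --job/--task flags and the cluster spec keys rather than documented anywhere in this commit:

    python examples/distributed.py --job ps --task 0
    python examples/distributed.py --job worker --task 0
    python examples/distributed.py --job worker --task 1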
tensorpack/tfutils/summary.py
...
@@ -156,9 +156,10 @@ def add_moving_summary(v, *args, **kwargs):
         assert x.get_shape().ndims == 0, x.get_shape()
     # TODO will produce tower0/xxx?
     # TODO use zero_debias
-    with tf.name_scope(None):
+    gs = get_global_step_var()
+    with tf.name_scope(None), tf.device(gs.device):
         averager = tf.train.ExponentialMovingAverage(
-            decay, num_updates=get_global_step_var(), name='EMA')
+            decay, num_updates=gs, name='EMA')
         avg_maintain_op = averager.apply(v)
     for c in v:
...
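The point of this hunk: fetch the global-step variable once, then create the EMA under tf.device(gs.device), so the moving-average variables are colocated with the global step (typically on a parameter server in a distributed job) instead of landing on whatever device the caller happens to be on. A minimal standalone sketch of the same pattern, assuming TF 1.x and that get_global_step_var is importable from tensorpack's top-level namespace (as the example above implies via `from tensorpack import *`); the averaged scalar is made up for illustration:

    import tensorflow as tf
    from tensorpack import get_global_step_var

    # made-up scalar standing in for the summary tensors `v`
    fake_loss = tf.reduce_mean(tf.random_normal([32]), name='fake_loss')

    gs = get_global_step_var()                       # fetch the step variable once
    with tf.name_scope(None), tf.device(gs.device):  # colocate the EMA variables with it
        averager = tf.train.ExponentialMovingAverage(0.95, num_updates=gs, name='EMA')
        avg_maintain_op = averager.apply([fake_loss])  # op that updates the averages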
tensorpack/train/distributed.py  (new file, mode 100644)

(269-line diff collapsed; contents not shown on this page)