Commit 1a348f00 authored by Yuxin Wu's avatar Yuxin Wu

can run

parent a3674b47
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: mnist-dist.py
import numpy as np
import os
import sys
import argparse
"""
MNIST ConvNet example.
about 0.6% validation error after 30 epochs.
"""
# Just import everything into current namespace
from tensorpack import *
import tensorflow as tf
import tensorpack.tfutils.symbolic_functions as symbf
IMAGE_SIZE = 28
class Model(ModelDesc):
def _get_inputs(self):
"""
Define all the inputs (with type, shape, name) that
the graph will need.
"""
return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
InputDesc(tf.int32, (None,), 'label')]
def _build_graph(self, inputs):
"""This function should build the model which takes the input variables
and define self.cost at the end"""
# inputs contains a list of input variables defined above
image, label = inputs
# In tensorflow, inputs to convolution function are assumed to be
# NHWC. Add a single channel here.
image = tf.expand_dims(image, 3)
image = image * 2 - 1 # center the pixels values at zero
# The context manager `argscope` sets the default option for all the layers under
# this context. Here we use 32 channel convolution with shape 3x3
with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32):
logits = (LinearWrap(image)
.Conv2D('conv0')
.MaxPooling('pool0', 2)
.Conv2D('conv1')
.Conv2D('conv2')
.MaxPooling('pool1', 2)
.Conv2D('conv3')
.FullyConnected('fc0', 512, nl=tf.nn.relu)
.Dropout('dropout', 0.5)
.FullyConnected('fc1', out_dim=10, nl=tf.identity)())
prob = tf.nn.softmax(logits, name='prob') # a Bx10 with probabilities
# a vector of length B with loss of each sample
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') # the average cross-entropy loss
# compute the "incorrect vector", for the callback ClassificationError to use at validation time
wrong = symbf.prediction_incorrect(logits, label, name='incorrect')
accuracy = symbf.accuracy(logits, label, name='accuracy')
# This will monitor training error (in a moving_average fashion):
# 1. write the value to tensosrboard
# 2. write the value to stat.json
# 3. print the value after each epoch
train_error = tf.reduce_mean(wrong, name='train_error')
summary.add_moving_summary(train_error, accuracy)
# Use a regex to find parameters to apply weight decay.
# Here we apply a weight decay on all W (weight matrix) of all fc layers
wd_cost = tf.multiply(1e-5,
regularize_cost('fc.*/W', tf.nn.l2_loss),
name='regularize_loss')
self.cost = tf.add_n([wd_cost, cost], name='total_cost')
summary.add_moving_summary(cost, wd_cost, self.cost)
# monitor histogram of all weight (of conv and fc layers) in tensorboard
summary.add_param_summary(('.*/W', ['histogram', 'rms']))
def _get_optimizer(self):
lr = tf.train.exponential_decay(
learning_rate=1e-3,
global_step=get_global_step_var(),
decay_steps=468 * 10,
decay_rate=0.3, staircase=True, name='learning_rate')
# This will also put the summary in tensorboard, stat.json and print in terminal
# but this time without moving average
tf.summary.scalar('lr', lr)
return tf.train.AdamOptimizer(lr)
def get_data():
train = BatchData(dataset.Mnist('train'), 128)
test = BatchData(dataset.Mnist('test'), 256, remainder=True)
return train, test
def get_config():
# automatically setup the directory train_log/mnist-convnet for logging
logger.auto_set_dir('k')
dataset_train, dataset_test = get_data()
# How many iterations you want in each epoch.
# This is the default value, don't actually need to set it in the config
steps_per_epoch = dataset_train.size()
# get the config which contains everything necessary in a training
return TrainConfig(
model=Model(),
dataflow=dataset_train, # the DataFlow instance for training
callbacks=[
#ModelSaver(), # save the model after every epoch
#MaxSaver('validation_accuracy'), # save the model with highest accuracy (prefix 'validation_')
#InferenceRunner( # run inference(for validation) after every epoch
#dataset_test, # the DataFlow instance used for validation
## Calculate both the cost and the error for this DataFlow
#[ScalarStats('cross_entropy_loss'), ScalarStats('accuracy'),
#ClassificationError('incorrect')]),
],
steps_per_epoch=steps_per_epoch,
max_epoch=100,
)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
parser.add_argument('--job', required=True)
parser.add_argument('--task', type=int, default=0)
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
cluster_spec = tf.train.ClusterSpec({
'ps': ['0.0.0.0:2222'],
'worker': ['0.0.0.0:2223', '0.0.0.0:2224']
})
config.data = QueueInput(config.dataflow)
DistributedReplicatedTrainer(config, args.job, args.task, cluster_spec).train()
...@@ -156,9 +156,10 @@ def add_moving_summary(v, *args, **kwargs): ...@@ -156,9 +156,10 @@ def add_moving_summary(v, *args, **kwargs):
assert x.get_shape().ndims == 0, x.get_shape() assert x.get_shape().ndims == 0, x.get_shape()
# TODO will produce tower0/xxx? # TODO will produce tower0/xxx?
# TODO use zero_debias # TODO use zero_debias
with tf.name_scope(None): gs = get_global_step_var()
with tf.name_scope(None), tf.device(gs.device):
averager = tf.train.ExponentialMovingAverage( averager = tf.train.ExponentialMovingAverage(
decay, num_updates=get_global_step_var(), name='EMA') decay, num_updates=gs, name='EMA')
avg_maintain_op = averager.apply(v) avg_maintain_op = averager.apply(v)
for c in v: for c in v:
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: distributed.py
import tensorflow as tf
from six.moves import range
import weakref
from tensorflow.python.training.monitored_session \
import _HookedSession as HookedSession
from ..utils import logger
from .input_source import StagingInputWrapper, FeedfreeInput
from .feedfree import SingleCostFeedfreeTrainer
from .multigpu import MultiGPUTrainerBase
from ..tfutils.model_utils import describe_model
from ..callbacks import Callbacks, ProgressBar
from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.common import get_default_sess_config, get_global_step_var
from ..callbacks.monitor import Monitors
__all__ = ['DistributedReplicatedTrainer']
PS_SHADOW_VAR_PREFIX = 'ps_var'
# To be used with custom_getter on tf.get_variable. Ensures the created variable
# is in LOCAL_VARIABLES and not GLOBAL_VARIBLES collection.
class OverrideToLocalVariableIfNotPsVar(object):
# args and kwargs come from the custom_getter interface for Tensorflow
# variables, and matches tf.get_variable's signature, with the addition of
# 'getter' at the beginning.
def __call__(self, getter, name, *args, **kwargs):
if name.startswith(PS_SHADOW_VAR_PREFIX):
return getter(*args, **kwargs)
logger.info("CustomGetter-{}".format(name))
if 'collections' in kwargs:
collections = kwargs['collections']
if not collections:
collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
else:
collections = set(collections.copy())
collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
collections.add(tf.GraphKeys.LOCAL_VARIABLES)
kwargs['collections'] = list(collections)
return getter(name, *args, **kwargs)
class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
def __init__(self, config, job_name, task_index, cluster):
assert job_name in ['ps', 'worker'], job_name
self.config = config
self.job_name = job_name
self.task_index = task_index
self.cluster = cluster
self._input_source = config.data
super(DistributedReplicatedTrainer, self).__init__(config)
worker_prefix = '/job:worker/task:%s' % self.task_index
self.param_server_device = tf.train.replica_device_setter(
worker_device=worker_prefix + '/cpu:0', cluster=self.cluster)
# This device on which the queues for managing synchronization between
# servers should be stored.
num_ps = self.cluster.num_tasks('ps')
self.cpu_device = '%s/cpu:0' % worker_prefix
self.nr_gpu = config.nr_tower
self.raw_devices = ['%s/%s:%i' % (worker_prefix, 'gpu', i) for i in range(self.nr_gpu)]
self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(num_ps)]
self.sync_queue_counter = 0
if self.nr_gpu > 1:
assert tf.test.is_gpu_available()
# seem to only improve on >1 GPUs
if not isinstance(self._input_source, StagingInputWrapper):
self._input_source = StagingInputWrapper(self._input_source, self.raw_devices)
def _setup(self):
conf = get_default_sess_config()
self.server = tf.train.Server(
self.cluster, job_name=self.job_name,
task_index=self.task_index,
config=conf # TODO sessconfig
)
if self.job_name == 'ps':
logger.info("Running ps {}".format(self.task_index))
self.server.join()
return
with tf.variable_scope(
tf.get_variable_scope(),
custom_getter=OverrideToLocalVariableIfNotPsVar()):
# Ngpu * Nvar * 2
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower,
lambda: self._get_cost_and_grad()[1],
devices=self.raw_devices,
var_strategy='replicated')
# (g, v) to be applied, where v is global (ps vars)
new_tower_grads = []
for i, grad_and_vars in enumerate(zip(*grad_list)):
# Ngpu * 2
with tf.device(self.raw_devices[i % self.nr_gpu]):
v = grad_and_vars[0][1]
if self.nr_gpu > 1:
# average gradient
all_grads = [g for (g, _) in grad_and_vars]
if not MultiGPUTrainerBase.check_none_grads(v.op.name, all_grads):
continue
grad = tf.multiply(
tf.add_n(all_grads), 1.0 / self.nr_gpu)
else:
grad = grad_and_vars[0][0]
with tf.device(self.param_server_device):
my_name = PS_SHADOW_VAR_PREFIX + '/' + v.name
if my_name.endswith(':0'):
my_name = my_name[:-2]
new_v = tf.get_variable(my_name, dtype=v.dtype.base_dtype,
initializer=v.initial_value,
trainable=True)
new_tower_grads.append((grad, new_v))
# apply gradients TODO do this for each variable separately?
opt = self.model.get_optimizer()
apply_gradient_op = opt.apply_gradients(new_tower_grads)
barrier = self.add_sync_queues_and_barrier('replicate_variable', [apply_gradient_op])
var_update_ops = []
with tf.control_dependencies([barrier]), \
tf.device(self.cpu_device):
for idx, (grad, v) in enumerate(new_tower_grads):
updated_value = v.read_value()
for towerid in range(self.nr_gpu):
logger.info("Step update {} -> {}".format(v.name, grad_list[towerid][idx][1].name))
var_update_ops.append(
grad_list[towerid][idx][1].assign(updated_value))
self.main_fetch = tf.group(*var_update_ops, name='main_fetches')
self.train_op = self.add_sync_queues_and_barrier('sync_queues_step_end', [self.main_fetch])
self.post_init_op = self.get_post_init_ops()
def setup(self):
with tf.device(self.param_server_device):
gs = get_global_step_var()
self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
self._input_source.setup_training(self)
self._setup()
self.monitors = Monitors(self.monitors)
self.register_callback(self.monitors)
describe_model()
# some final operations that might modify the graph
logger.info("Setup callbacks graph ...")
#if not self.is_chief:
#self._callbacks = [ProgressBar()]
self._callbacks = Callbacks(self._callbacks)
self._callbacks.setup_graph(weakref.proxy(self))
#local_init_op = tf.local_variables_initializer()
global_init_op = tf.global_variables_initializer()
logger.info("Finalize the graph, create the session ...")
self.sv = tf.train.Supervisor(
is_chief=self.is_chief,
logdir=None,
saver=None,
global_step=gs,
summary_op=None,
save_model_secs=0,
#local_init_op=local_init_op,
#ready_for_local_init_op=None,
summary_writer=None)
conf = get_default_sess_config()
sess = self.sv.prepare_or_wait_for_session(
master=self.server.target,
config=conf,
start_standard_services=False)
self.sess = sess
if self.is_chief:
print([k.name for k in tf.global_variables()])
sess.run(global_init_op)
logger.info("Global variables initialized.")
#sess.run(local_init_op)
#if self.is_chief:
#self.config.session_init.init(self.sess)
#self.sess.graph.finalize()
#else:
#logger.info("Worker {} waiting for chief".format(self.task_index))
#self.sess = tf.train.WorkerSessionCreator(master=self.server.target).create_session()
#logger.info("Worker wait finished")
#self.sess.run(local_init_op)
#logger.info("local init op runned")
logger.info("Running post init op...")
sess.run(self.post_init_op)
logger.info("Post init op finished.")
self._monitored_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=None)
#self._monitored_sess = self.sv
hooks = self._callbacks.get_hooks()
self.hooked_sess = HookedSession(self.sess, hooks)
def add_sync_queues_and_barrier(self, name_prefix, enqueue_after_list):
"""Adds ops to enqueue on all worker queues.
Args:
name_prefix: prefixed for the shared_name of ops.
enqueue_after_list: control dependency from ops.
Returns:
an op that should be used as control dependency before starting next step.
"""
self.sync_queue_counter += 1
num_workers = self.cluster.num_tasks('worker')
with tf.device(self.sync_queue_devices[self.sync_queue_counter % len(self.sync_queue_devices)]):
sync_queues = [
tf.FIFOQueue(num_workers, [tf.bool], shapes=[[]],
shared_name='%s%s' % (name_prefix, i))
for i in range(num_workers)]
queue_ops = []
# For each other worker, add an entry in a queue, signaling that it can
# finish this step.
token = tf.constant(False)
with tf.control_dependencies(enqueue_after_list):
for i, q in enumerate(sync_queues):
if i == self.task_index:
queue_ops.append(tf.no_op())
else:
queue_ops.append(q.enqueue(token))
# Drain tokens off queue for this worker, one for each other worker.
queue_ops.append(
sync_queues[self.task_index].dequeue_many(len(sync_queues) - 1))
return tf.group(*queue_ops)
def get_post_init_ops(self):
# Copy initialized variables for variables on the parameter server
# to the local copy of the variable.
def strip_port(s):
if s.endswith(':0'):
return s[:-2]
return s
local_vars = tf.local_variables()
local_var_by_name = dict([(strip_port(v.name), v) for v in local_vars])
post_init_ops = []
for v in tf.global_variables():
if v.name.startswith(PS_SHADOW_VAR_PREFIX + '/'):
prefix = strip_port(
v.name[len(PS_SHADOW_VAR_PREFIX + '/'):])
for i in range(self.nr_gpu):
if i == 0:
name = prefix
else:
name = 'tower%s/%s' % (i, prefix)
if name in local_var_by_name:
copy_to = local_var_by_name[name]
logger.info("Post Init {} -> {}".format(v.name, copy_to.name))
post_init_ops.append(copy_to.assign(v.read_value()))
else:
logger.warn("Global var {} doesn't match local var".format(v.name))
return tf.group(*post_init_ops, name='post_init_ops')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment