Commit 476d7cf0 authored by Yuxin Wu

fix optimizer device. try cifar

parent 2cfefc90
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: mnist-dist.py
import numpy as np
import os
import sys
import argparse
"""
MNIST ConvNet example.
about 0.6% validation error after 30 epochs.
"""
# Just import everything into current namespace
from tensorpack import *
import tensorflow as tf
import tensorpack.tfutils.symbolic_functions as symbf
IMAGE_SIZE = 28
class Model(ModelDesc):
    def _get_inputs(self):
        """
        Define all the inputs (with type, shape, name) that
        the graph will need.
        """
        return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
                InputDesc(tf.int32, (None,), 'label')]
    def _build_graph(self, inputs):
        """This function should build the model which takes the input variables
        and define self.cost at the end."""

        # inputs contains a list of input variables defined above
        image, label = inputs
        # In tensorflow, inputs to convolution functions are assumed to be
        # NHWC. Add a single channel here.
        image = tf.expand_dims(image, 3)

        image = image * 2 - 1   # center the pixel values at zero

        # The context manager `argscope` sets the default option for all the layers under
        # this context. Here we use 32-channel convolutions with 3x3 kernels.
        with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32):
            logits = (LinearWrap(image)
                      .Conv2D('conv0')
                      .MaxPooling('pool0', 2)
                      .Conv2D('conv1')
                      .Conv2D('conv2')
                      .MaxPooling('pool1', 2)
                      .Conv2D('conv3')
                      .FullyConnected('fc0', 512, nl=tf.nn.relu)
                      .Dropout('dropout', 0.5)
                      .FullyConnected('fc1', out_dim=10, nl=tf.identity)())
        prob = tf.nn.softmax(logits, name='prob')   # a Bx10 tensor of probabilities

        # a vector of length B with the loss of each sample
        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
        cost = tf.reduce_mean(cost, name='cross_entropy_loss')  # the average cross-entropy loss

        # compute the "incorrect vector", for the callback ClassificationError to use at validation time
        wrong = symbf.prediction_incorrect(logits, label, name='incorrect')
        accuracy = symbf.accuracy(logits, label, name='accuracy')

        # This will monitor training error (in a moving-average fashion):
        # 1. write the value to tensorboard
        # 2. write the value to stat.json
        # 3. print the value after each epoch
        train_error = tf.reduce_mean(wrong, name='train_error')
        summary.add_moving_summary(train_error, accuracy)

        # Use a regex to find parameters to apply weight decay.
        # Here we apply a weight decay on all W (weight matrices) of all fc layers.
        wd_cost = tf.multiply(1e-5,
                              regularize_cost('fc.*/W', tf.nn.l2_loss),
                              name='regularize_loss')
        self.cost = tf.add_n([wd_cost, cost], name='total_cost')
        summary.add_moving_summary(cost, wd_cost, self.cost)

        # monitor histograms of all weights (of conv and fc layers) in tensorboard
        summary.add_param_summary(('.*/W', ['histogram', 'rms']))
    def _get_optimizer(self):
        lr = tf.train.exponential_decay(
            learning_rate=1e-3,
            global_step=get_global_step_var(),
            decay_steps=468 * 10,
            decay_rate=0.3, staircase=True, name='learning_rate')
        # This will also put the summary in tensorboard, stat.json and print in terminal,
        # but this time without moving average
        tf.summary.scalar('lr', lr)
        return tf.train.AdamOptimizer(lr)
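
# For reference: with batch size 128 the MNIST training set gives roughly 468
# steps per epoch, so decay_steps=468 * 10 lowers the learning rate about every
# 10 epochs. A plain-Python sketch of the staircase schedule requested above
# (illustrative only, not used by the model):
def _staircase_lr(step, base_lr=1e-3, decay_steps=468 * 10, decay_rate=0.3):
    # exponential_decay with staircase=True: lr = base_lr * rate ** floor(step / decay_steps)
    return base_lr * decay_rate ** (step // decay_steps)
# e.g. epochs 0-9 train at 1e-3, epochs 10-19 at 3e-4, epochs 20-29 at 9e-5, ...
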
def get_data():
    train = BatchData(dataset.Mnist('train'), 128)
    test = BatchData(dataset.Mnist('test'), 256, remainder=True)
    return train, test
def get_config():
    # automatically set up the directory train_log/mnist-dist for logging ('k': keep the existing directory)
    logger.auto_set_dir('k')
    dataset_train, dataset_test = get_data()

    # How many iterations you want in each epoch.
    # This is the default value; you don't actually need to set it in the config.
    steps_per_epoch = dataset_train.size()

    # get the config which contains everything necessary in a training
    return TrainConfig(
        model=Model(),
        dataflow=dataset_train,   # the DataFlow instance for training
        callbacks=[
            # ModelSaver(),   # save the model after every epoch
            # MaxSaver('validation_accuracy'),  # save the model with highest accuracy (prefix 'validation_')
            # InferenceRunner(    # run inference (for validation) after every epoch
            #     dataset_test,   # the DataFlow instance used for validation
            #     # Calculate both the cost and the error for this DataFlow
            #     [ScalarStats('cross_entropy_loss'), ScalarStats('accuracy'),
            #      ClassificationError('incorrect')]),
        ],
        steps_per_epoch=steps_per_epoch,
        max_epoch=100,
    )
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--load', help='load model')
    parser.add_argument('--job', required=True)
    parser.add_argument('--task', type=int, default=0)
    args = parser.parse_args()
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    config = get_config()
    if args.load:
        config.session_init = SaverRestore(args.load)

    cluster_spec = tf.train.ClusterSpec({
        'ps': ['0.0.0.0:2222'],
        'worker': ['0.0.0.0:2223', '0.0.0.0:2224']
    })
    config.data = QueueInput(config.dataflow)
    DistributedReplicatedTrainer(config, args.job, args.task, cluster_spec).train()
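
# The hard-coded ClusterSpec above expects one parameter server and two workers
# on the local machine, so one would presumably launch three processes, e.g.:
#
#   python mnist-dist.py --job ps --task 0
#   python mnist-dist.py --job worker --task 0 --gpu 0
#   python mnist-dist.py --job worker --task 1 --gpu 1
#
# (illustrative only; the GPU assignment is an assumption, and only the
#  --job/--task/--gpu/--load flags defined by argparse above exist.)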
@@ -72,7 +72,7 @@ class GraphVarParam(HyperParam):
                 self.var = v
                 break
         else:
-            raise ValueError("{} is not a VARIABLE in the graph!".format(self.var_name))
+            raise ValueError("{} is not a GLOBAL_VARIABLE in the graph!".format(self.var_name))

     def set_value(self, v):
         """ Assign the variable a new value. """
...
@@ -90,6 +90,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
             logger.info("Running ps {}".format(self.task_index))
             self.server.join()
             return
+        opt = self.model.get_optimizer()    # in global scope, not local
         with tf.variable_scope(
                 tf.get_variable_scope(),
                 custom_getter=OverrideToLocalVariableIfNotPsVar()):
@@ -126,25 +127,28 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
                 new_tower_grads.append((grad, new_v))

             # apply gradients TODO do this for each variable separately?
-            opt = self.model.get_optimizer()
-            apply_gradient_op = opt.apply_gradients(new_tower_grads)
-            barrier = self.add_sync_queues_and_barrier('replicate_variable', [apply_gradient_op])
             var_update_ops = []
-            with tf.control_dependencies([barrier]), \
-                    tf.device(self.cpu_device):
-                for idx, (grad, v) in enumerate(new_tower_grads):
-                    updated_value = v.read_value()
-                    for towerid in range(self.nr_gpu):
-                        logger.info("Step update {} -> {}".format(v.name, grad_list[towerid][idx][1].name))
-                        var_update_ops.append(
-                            grad_list[towerid][idx][1].assign(updated_value))
+            with tf.device(self.param_server_device):
+                for vid, (g, v) in enumerate(new_tower_grads):
+                    apply_gradient_op = opt.apply_gradients([(g, v)])
+                    barrier = self.add_sync_queues_and_barrier(
+                        'param_update_barrier_{}'.format(vid), [apply_gradient_op])
+                    with tf.control_dependencies([barrier]), \
+                            tf.device(self.cpu_device):
+                        updated_value = v.read_value()
+                        for towerid in range(self.nr_gpu):
+                            logger.info("Step update {} -> {}".format(v.name, grad_list[towerid][vid][1].name))
+                            var_update_ops.append(
+                                grad_list[towerid][vid][1].assign(updated_value))

             self.main_fetch = tf.group(*var_update_ops, name='main_fetches')
-            self.train_op = self.add_sync_queues_and_barrier('sync_queues_step_end', [self.main_fetch])
+            self.train_op = self.main_fetch
+            # self.train_op = self.add_sync_queues_and_barrier('sync_queues_step_end', [self.main_fetch])
             self.post_init_op = self.get_post_init_ops()

     def setup(self):
         with tf.device(self.param_server_device):
             gs = get_global_step_var()
-            opt = self.model.get_optimizer()    # in global scope, not local

         assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
         self._input_source.setup_training(self)
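
The per-variable barriers above are built by add_sync_queues_and_barrier, whose implementation is not part of this diff. For orientation, here is a minimal sketch of the usual shared-FIFOQueue barrier pattern (TF 1.x) that such a helper typically follows; the function name, signature, device string and queue naming below are assumptions for illustration, not the trainer's actual code:

    import tensorflow as tf

    def sync_queues_and_barrier(name, num_workers, task_index, dependency_ops):
        # One shared queue per worker, all placed on the parameter server.
        with tf.device('/job:ps/task:0/cpu:0'):
            sync_queues = [
                tf.FIFOQueue(num_workers, [tf.bool], shapes=[[]],
                             shared_name='{}_{}'.format(name, i))
                for i in range(num_workers)]
        # After the guarded ops finish, put one token into every *other* worker's queue ...
        with tf.control_dependencies(dependency_ops):
            enqueue_ops = [q.enqueue(True)
                           for i, q in enumerate(sync_queues) if i != task_index]
        # ... then block until every other worker has put a token into ours.
        with tf.control_dependencies(enqueue_ops):
            return sync_queues[task_index].dequeue_many(num_workers - 1)

The change in this hunk is that such a barrier is now created once per variable (keyed by vid) around each single-variable apply_gradients, instead of once around a single apply_gradients over all gradients.
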
@@ -153,17 +157,10 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         self.monitors = Monitors(self.monitors)
         self.register_callback(self.monitors)

         describe_model()
-        # some final operations that might modify the graph
         logger.info("Setup callbacks graph ...")
-        # if not self.is_chief:
-        #     self._callbacks = [ProgressBar()]
         self._callbacks = Callbacks(self._callbacks)
         self._callbacks.setup_graph(weakref.proxy(self))
-        # local_init_op = tf.local_variables_initializer()
-        global_init_op = tf.global_variables_initializer()

         logger.info("Finalize the graph, create the session ...")
         self.sv = tf.train.Supervisor(
...