Commit 977134e1 authored by Yuxin Wu

after_train

parent 696a7db7
 #!/usr/bin/env python2
 # -*- coding: UTF-8 -*-
-# File: loyaltry.py
+# File: example_cifar10.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf
@@ -16,10 +16,15 @@ from tensorpack.utils.symbolic_functions import *
 from tensorpack.utils.summary import *
 from tensorpack.dataflow import *
 from tensorpack.dataflow import imgaug
-from cifar10 import cifar10
+"""
+This config follows the same preprocessing/model/hyperparameters as in the
+tensorflow cifar10 example (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/models/image/cifar10/),
+but is faster.
+"""
 BATCH_SIZE = 128
-MIN_AFTER_DEQUEUE = 20000   # a large number, as in the official example
+MIN_AFTER_DEQUEUE = int(50000 * 0.4)
 CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE
 def get_model(inputs, is_training):
@@ -50,7 +55,7 @@ def get_model(inputs, is_training):
     l = FullyConnected('fc1', l, out_dim=192,
                        W_init=tf.truncated_normal_initializer(stddev=0.04),
                        b_init=tf.constant_initializer(0.1))
-    ## fc will have activation summary by default. disable this for the output layer
+    # fc will have activation summary by default. disable for the output layer
     logits = FullyConnected('linear', l, out_dim=10, summary_activation=False,
                             nl=tf.identity,
                             W_init=tf.truncated_normal_initializer(stddev=1.0/192))
@@ -91,14 +96,14 @@ def get_config():
         Flip(horiz=True),
         BrightnessAdd(63),
         Contrast((0.2,1.8)),
-        PerImageWhitening(all_channel=True)
+        MeanVarianceNormalize(all_channel=True)
     ]
     dataset_train = AugmentImageComponent(dataset_train, augmentors)
     dataset_train = BatchData(dataset_train, 128)
     augmentors = [
         CenterCrop((24, 24)),
-        PerImageWhitening(all_channel=True)
+        MeanVarianceNormalize(all_channel=True)
     ]
     dataset_test = dataset.Cifar10('test')
     dataset_test = AugmentImageComponent(dataset_test, augmentors)
@@ -107,7 +112,6 @@ def get_config():
     sess_config = get_default_sess_config()
     sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
-    sess_config.device_count['GPU'] = 2
     # prepare model
     input_vars = [
@@ -150,6 +154,8 @@ if __name__ == '__main__':
     args = parser.parse_args()
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
+    else:
+        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
     with tf.Graph().as_default():
         with tf.device('/cpu:0'):
......
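The new `MIN_AFTER_DEQUEUE` expression evaluates to the same 20000 as the old hard-coded constant, but makes the intent explicit: 40% of the 50000-image CIFAR-10 training set, matching the official example the removed comment referred to. A minimal sketch of the arithmetic:

```python
# Queue sizing from the hunk above; 50000 is the CIFAR-10 training-set size.
BATCH_SIZE = 128
MIN_AFTER_DEQUEUE = int(50000 * 0.4)           # = 20000, same value as before
CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE  # = 20384 slots in the queue
assert MIN_AFTER_DEQUEUE == 20000 and CAPACITY == 20384
```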
@@ -31,13 +31,15 @@ class Callback(object):
         Called before starting iterative training
         """
-    def trigger_step(self, inputs, outputs, cost):
+    def after_train(self):
+        """
+        Called after training
+        """
+
+    def trigger_step(self):
         """
         Callback to be triggered after every step (every backpropagation)
-        Args:
-            inputs: the list of input values
-            outputs: list of output values after running this inputs
-            cost: the cost value after running this input
+        Could be useful to apply some tricks on parameters (clipping, low-rank, etc.)
         """
     def trigger_epoch(self):
......
@@ -56,3 +56,6 @@ class SummaryWriter(Callback):
             logger.info('{}: {:.4f}'.format(val.tag, val.simple_value))
         self.writer.add_summary(summary, get_global_step())
+
+    def after_train(self):
+        self.writer.close()
@@ -78,9 +78,13 @@ class TrainCallbacks(Callback):
             cb.before_train()
         self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
-    def trigger_step(self, inputs, outputs, cost):
+    def after_train(self):
         for cb in self.cbs:
-            cb.trigger_step(inputs, outputs, cost)
+            cb.after_train()
+
+    def trigger_step(self):
+        for cb in self.cbs:
+            cb.trigger_step()
     def trigger_epoch(self):
         tm = CallbackTimeLogger()
@@ -111,6 +115,10 @@ class TestCallbacks(Callback):
         for cb in self.cbs:
             cb.before_train()
+
+    def after_train(self):
+        for cb in self.cbs:
+            cb.after_train()
     def trigger_epoch(self):
         if not self.cbs:
             return
@@ -153,8 +161,12 @@ class Callbacks(Callback):
         self.train.before_train()
         self.test.before_train()
-    def trigger_step(self, inputs, outputs, cost):
-        self.train.trigger_step(inputs, outputs, cost)
+    def after_train(self):
+        self.train.after_train()
+        self.test.after_train()
+
+    def trigger_step(self):
+        self.train.trigger_step()
         # test callbacks don't have trigger_step
     def trigger_epoch(self):
......
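With `trigger_step` losing its `(inputs, outputs, cost)` arguments and `after_train` joining the interface, a user callback now looks roughly like the following. `StepCounter` is a hypothetical example against the `Callback` base class above, not part of this commit:

```python
# Hypothetical callback on the new interface: trigger_step() takes no
# arguments any more, and after_train() runs once when training finishes
# (normally or via the finally-block in start_train).
class StepCounter(Callback):
    def before_train(self):
        self.steps = 0

    def trigger_step(self):
        # called after every backpropagation step; per the new docstring,
        # a natural place for parameter tricks such as clipping
        self.steps += 1

    def after_train(self):
        print('trained for {} steps'.format(self.steps))
```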
@@ -6,7 +6,7 @@
 from .base import ImageAugmentor
 import numpy as np
-__all__ = ['BrightnessAdd', 'Contrast', 'PerImageWhitening']
+__all__ = ['BrightnessAdd', 'Contrast', 'MeanVarianceNormalize']
 class BrightnessAdd(ImageAugmentor):
     """
@@ -35,7 +35,7 @@ class Contrast(ImageAugmentor):
         img.arr = (arr - mean) * r + mean
         img.arr = np.clip(img.arr, 0, 255)
-class PerImageWhitening(ImageAugmentor):
+class MeanVarianceNormalize(ImageAugmentor):
     """
     Linearly scales image to have zero mean and unit norm.
     x = (x - mean) / adjusted_stddev
@@ -43,7 +43,6 @@ class PerImageWhitening(ImageAugmentor):
     """
     def __init__(self, all_channel=True):
         self.all_channel = all_channel
-        pass
     def _augment(self, img):
         if self.all_channel:
......
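In numpy terms, the renamed augmentor's docstring describes per-image mean/variance normalization. A minimal sketch, assuming the usual `adjusted_stddev = max(stddev, 1/sqrt(num_pixels))` floor from TensorFlow's `per_image_whitening` (the clipped portion of the docstring above defines the exact rule this class uses):

```python
import numpy as np

def mean_variance_normalize(arr, all_channel=True):
    # normalize over all channels together, or each channel independently
    axis = None if all_channel else (0, 1)
    mean = arr.mean(axis=axis, keepdims=axis is not None)
    std = arr.std(axis=axis, keepdims=axis is not None)
    num = arr.size if all_channel else arr.shape[0] * arr.shape[1]
    # floor the stddev so a constant image does not divide by ~0 (assumed rule)
    adjusted_std = np.maximum(std, 1.0 / np.sqrt(num))
    return (arr - mean) / adjusted_std
```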
@@ -10,7 +10,7 @@ import argparse
 import tqdm
 from utils import *
-from utils.concurrency import EnqueueThread,coordinator_guard
+from utils.concurrency import EnqueueThread
 from callbacks import *
 from utils.summary import summary_moving_average
 from utils.modelutils import describe_model
@@ -75,29 +75,12 @@ class TrainConfig(object):
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
 def average_gradients(tower_grads):
     average_grads = []
     for grad_and_vars in zip(*tower_grads):
-        # Note that each grad_and_vars looks like the following:
-        #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
-        grads = []
-        for g, _ in grad_and_vars:
-            # Add 0 dimension to the gradients to represent the tower.
-            expanded_g = tf.expand_dims(g, 0)
-            # Append on a 'tower' dimension which we will average over below.
-            grads.append(expanded_g)
-        # Average over the 'tower' dimension.
-        grad = tf.concat(0, grads)
-        grad = tf.reduce_mean(grad, 0)
-        # Keep in mind that the Variables are redundant because they are shared
-        # across towers. So .. we will just return the first tower's pointer to
-        # the Variable.
-        v = grad_and_vars[0][1]
-        grad_and_var = (grad, v)
-        average_grads.append(grad_and_var)
-    return average_grads
+        grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
+        v = grad_and_vars[0][1]
+        average_grads.append((grad, v))
+    return average_grads
 def summary_grads(grads):
     for grad, var in grads:
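The rewritten `average_gradients` computes the same per-variable mean over towers as the old expand_dims/concat/reduce_mean dance, with a single `add_n` and a division. A quick numpy check of the equivalence, on hypothetical values:

```python
import numpy as np

# two towers, one variable: grad_and_vars looks like ((g_gpu0, v), (g_gpu1, v))
grads_for_one_var = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]

old = np.stack(grads_for_one_var, axis=0).mean(axis=0)         # concat + reduce_mean
new = sum(grads_for_one_var) / float(len(grads_for_one_var))   # add_n + divide
assert np.allclose(old, new)  # both give [2.0, 3.0]
```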
@@ -139,17 +122,17 @@ def start_train(config):
     kept_summaries = {}
     grads = []
     for i in range(config.nr_tower):
-        with tf.device('/gpu:{}'.format(i)):
-            with tf.name_scope('tower{}'.format(i)) as scope:
+        with tf.device('/gpu:{}'.format(i)), \
+                tf.name_scope('tower{}'.format(i)) as scope:
             model_inputs = get_model_inputs()
             output_vars, cost_var = config.get_model_func(model_inputs, is_training=True)
             grads.append(
                 config.optimizer.compute_gradients(cost_var))
             if i == 0:
                 tf.get_variable_scope().reuse_variables()
                 for k in coll_keys:
                     kept_summaries[k] = copy.copy(tf.get_collection(k))
     for k in coll_keys:  # avoid repeating summary on multiple devices
         del tf.get_collection(k)[:]
         tf.get_collection(k).extend(kept_summaries[k])
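The two nested `with` blocks above collapse into a single statement; Python allows several context managers in one `with`, which saves an indentation level without changing behavior. A minimal sketch of the equivalence, using tower 0 as a stand-in:

```python
import tensorflow as tf

# before: nested context managers
with tf.device('/gpu:0'):
    with tf.name_scope('tower0') as scope:
        pass  # build the tower here

# after: one statement, same semantics, one less indent level
with tf.device('/gpu:0'), \
        tf.name_scope('tower0') as scope:
    pass  # build the tower here
```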
@@ -172,29 +155,31 @@ def start_train(config):
     # start training:
     coord = tf.train.Coordinator()
     # a thread that keeps filling the queue
-    input_th = EnqueueThread(sess, coord, enqueue_op, config.dataset, input_queue)
     model_th = tf.train.start_queue_runners(
         sess=sess, coord=coord, daemon=True, start=True)
+    input_th = EnqueueThread(sess, coord, enqueue_op, config.dataset, input_queue)
     input_th.start()
-    with sess.as_default(), \
-            coordinator_guard(sess, coord):
-        logger.info("Start with global_step={}".format(get_global_step()))
-        callbacks.before_train()
-        for epoch in xrange(1, config.max_epoch):
-            with timed_operation('epoch {}'.format(epoch)):
-                for step in tqdm.trange(
-                        config.step_per_epoch, leave=True, mininterval=0.2):
-                    if coord.should_stop():
-                        return
-                    # TODO if no one uses trigger_step, train_op can be
-                    # faster, see: https://github.com/soumith/convnet-benchmarks/pull/67/files
-                    fetches = [train_op, cost_var] + output_vars + model_inputs
-                    results = sess.run(fetches)
-                    cost = results[1]
-                    outputs = results[2:2 + len(output_vars)]
-                    inputs = results[-len(model_inputs):]
-                    callbacks.trigger_step(inputs, outputs, cost)
-            # note that summary_op will take a data from the queue.
-            callbacks.trigger_epoch()
+    with sess.as_default():
+        try:
+            logger.info("Start training with global_step={}".format(get_global_step()))
+            callbacks.before_train()
+            for epoch in xrange(1, config.max_epoch):
+                with timed_operation('epoch {}'.format(epoch)):
+                    for step in tqdm.trange(
+                            config.step_per_epoch, leave=True, mininterval=0.2):
+                        if coord.should_stop():
+                            return
+                        sess.run([train_op])  # faster, since train_op returns None
+                        callbacks.trigger_step()
+                # note that summary_op will take a datapoint from the queue.
+                callbacks.trigger_epoch()
+        except (KeyboardInterrupt, Exception):
+            raise
+        finally:
+            coord.request_stop()
+            queue.close(cancel_pending_enqueues=True)
+            callbacks.after_train()
+            sess.close()
@@ -37,8 +37,7 @@ def get_default_sess_config():
     Tensorflow's default session config consumes too many resources.
     """
     conf = tf.ConfigProto()
-    conf.device_count['GPU'] = 1
-    conf.gpu_options.per_process_gpu_memory_fraction = 0.8
+    conf.gpu_options.per_process_gpu_memory_fraction = 0.6
     conf.gpu_options.allocator_type = 'BFC'
     conf.allow_soft_placement = True
     return conf
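Callers are expected to start from this shared default and override per script, as example_cifar10.py above does with its own memory fraction. A minimal usage sketch:

```python
# Per-script override of the shared default config, mirroring
# example_cifar10.py above.
sess_config = get_default_sess_config()
sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
sess = tf.Session(config=sess_config)
```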
......
@@ -51,13 +51,3 @@ class EnqueueThread(threading.Thread):
             logger.exception("Exception in EnqueueThread:")
             self.queue.close(cancel_pending_enqueues=True)
             self.coord.request_stop()
-
-@contextmanager
-def coordinator_guard(sess, coord):
-    try:
-        yield
-    except (KeyboardInterrupt, Exception) as e:
-        raise
-    finally:
-        coord.request_stop()
-        sess.close()