Commit e072d909 authored by Yuxin Wu

[WIP] reorganize trainer. fix batch_norm

parent 335d6c28
@@ -8,14 +8,9 @@ import tensorflow as tf
import argparse
import os
from tensorpack.train import TrainConfig, QueueInputTrainer
from tensorpack.models import *
from tensorpack.callbacks import *
from tensorpack.utils import *
from tensorpack.tfutils import *
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import *
"""
CIFAR10-resnet example.
@@ -186,11 +181,9 @@ if __name__ == '__main__':
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
with tf.Graph().as_default():
with tf.device('/cpu:0'):
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
QueueInputTrainer(config).train()
SyncMultiGPUTrainer(config).train()
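# Usage sketch (not part of this commit; the script name is assumed):
# '--gpu 0,1' sets CUDA_VISIBLE_DEVICES='0,1' and nr_tower=2 above, so the
# example trains synchronously on two towers:
#   python cifar10-resnet.py --gpu 0,1 --load /path/to/checkpoint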
@@ -5,7 +5,9 @@
import tensorflow as tf
from copy import copy
import re
from ..utils import logger
from ._common import layer_register
__all__ = ['BatchNorm']
@@ -48,9 +50,28 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
else:
batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], keep_dims=False)
ema = tf.train.ExponentialMovingAverage(decay=decay)
emaname = 'EMA'
if not batch_mean.name.startswith('towerp'):
ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
ema_apply_op = ema.apply([batch_mean, batch_var])
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
else:
assert not use_local_stat
# have to do this again to get the actual name. see issue:
# https://github.com/tensorflow/tensorflow/issues/2740
ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
ema_apply_op = ema.apply([batch_mean, batch_var])
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
mean_name = re.sub('towerp[0-9]+/', '', ema_mean.name)
var_name = re.sub('towerp[0-9]+/', '', ema_var.name)
#var_name = batch_var.op.name[prefixlen:] + '/' + emaname + ':0'
#logger.info("In prediction, using {} instead of {} for {}".format(
#mean_name, ema_mean.name, batch_mean.name))
G = tf.get_default_graph()
ema_mean = G.get_tensor_by_name(mean_name)
ema_var = G.get_tensor_by_name(var_name)
if use_local_stat:
with tf.control_dependencies([ema_apply_op]):
@@ -58,6 +79,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
x, batch_mean, batch_var, beta, gamma, epsilon, 'bn')
else:
batch = tf.cast(tf.shape(x)[0], tf.float32)
# XXX TODO batch==1?
mean, var = ema_mean, ema_var * batch / (batch - 1) # unbiased variance estimator
return tf.nn.batch_normalization(
x, mean, var, beta, gamma, epsilon, 'bn')
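# Self-contained sketch (not part of this commit) of the name remapping above:
# EMA shadow variables are created by the training tower, so a prediction
# tower ('towerp<k>/') strips its own scope prefix and fetches the shared
# tensor by name. The concrete names below are illustrative.
import re
towerp_name = 'towerp0/conv1/bn/mean/EMA:0'             # name seen in the predict tower
shared_name = re.sub('towerp[0-9]+/', '', towerp_name)
assert shared_name == 'conv1/bn/mean/EMA:0'
# tf.get_default_graph().get_tensor_by_name(shared_name) then yields the EMA
# tensor maintained by the training tower.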
@@ -5,13 +5,19 @@
from ..utils.naming import *
import tensorflow as tf
from copy import copy
import six
from contextlib import contextmanager
__all__ = ['get_default_sess_config',
'get_global_step',
'get_global_step_var',
'get_op_var_name',
'get_vars_by_names'
]
'get_vars_by_names',
'backup_collection',
'restore_collection',
'clear_collection',
'freeze_collection']
def get_default_sess_config(mem_fraction=0.9):
"""
@@ -66,3 +72,24 @@ def get_vars_by_names(names):
opn, varn = get_op_var_name(n)
ret.append(G.get_tensor_by_name(varn))
return ret
def backup_collection(keys):
ret = {}
for k in keys:
ret[k] = copy(tf.get_collection(k))
return ret
def restore_collection(backup):
for k, v in six.iteritems(backup):
del tf.get_collection_ref(k)[:]
tf.get_collection_ref(k).extend(v)
def clear_collection(keys):
for k in keys:
del tf.get_collection_ref(k)[:]
@contextmanager
def freeze_collection(keys):
backup = backup_collection(keys)
yield
restore_collection(backup)
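# Usage sketch (not part of this commit): freeze_collection makes a collection
# immune to additions inside the block; the trainers below use it to build the
# prediction towers without polluting the training summaries.
with freeze_collection([tf.GraphKeys.SUMMARIES]):
    tf.add_to_collection(tf.GraphKeys.SUMMARIES, tf.constant(0.))
# here the collection holds exactly what it held before the block; the
# constant added inside was discarded by restore_collection.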
@@ -16,7 +16,6 @@ from ..utils.concurrency import start_proc_mask_signal
from ..callbacks import StatHolder
from ..tfutils import *
from ..tfutils.summary import create_summary
from ..tfutils.modelutils import describe_model
__all__ = ['Trainer']
@@ -141,7 +140,6 @@ class Trainer(object):
self.sess.close()
def init_session_and_coord(self):
describe_model()
self.sess = tf.Session(config=self.config.session_config)
self.coord = tf.train.Coordinator()
@@ -5,7 +5,6 @@
import tensorflow as tf
import threading
import time
import copy
import re
import functools
from six.moves import zip
@@ -15,6 +14,7 @@ from ..dataflow.common import RepeatedData
from ..utils import *
from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average
from ..tfutils.modelutils import describe_model
from ..tfutils import *
__all__ = ['SimpleTrainer', 'QueueInputTrainer',
@@ -42,6 +42,7 @@ class SimpleTrainer(Trainer):
avg_maintain_op)
self.init_session_and_coord()
describe_model()
# create an infinite data producer
self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
self.main_loop()
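# Sketch (not part of this commit; `my_dataflow` is hypothetical):
# RepeatedData(ds, -1) cycles a DataFlow forever, so the producer above never
# stops at an epoch boundary:
#   producer = RepeatedData(my_dataflow, -1).get_data()
#   datapoint = next(producer)   # one datapoint per training step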
@@ -100,14 +101,11 @@ class EnqueueThread(threading.Thread):
logger.info("Enqueue Thread Exited.")
class QueueInputTrainer(Trainer):
"""
Trainer which builds a FIFO queue for input.
Support multi GPU.
"""
""" Single GPU Trainer, takes input from a queue"""
SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
def __init__(self, config, input_queue=None,
async=False,
predict_tower=None):
def __init__(self, config, input_queue=None, predict_tower=None):
"""
:param config: a `TrainConfig` instance
:param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
@@ -120,27 +118,11 @@ class QueueInputTrainer(Trainer):
100, [x.dtype for x in self.input_vars], name='input_queue')
else:
self.input_queue = input_queue
self.async = async
if self.async:
assert self.config.nr_tower > 1
self.dequed_inputs = []
if predict_tower is None:
# by default, only use first training tower for prediction
# by default, use first training tower for prediction
predict_tower = [0]
self.predict_tower = predict_tower
@staticmethod
def _average_grads(tower_grads):
ret = []
for grad_and_vars in zip(*tower_grads):
v = grad_and_vars[0][1]
try:
grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
except AssertionError:
logger.error("Error while processing gradients of {}".format(v.name))
raise
ret.append((grad, v))
return ret
self.dequed_inputs = None
def _get_model_inputs(self):
""" Dequeue a datapoint from input_queue and return"""
@@ -150,42 +132,111 @@
assert len(ret) == len(self.input_vars)
for qv, v in zip(ret, self.input_vars):
qv.set_shape(v.get_shape())
self.dequed_inputs.append(ret)
return ret
def _build_predict_tower(self):
inputs = self.model.get_input_vars()
tf.get_variable_scope().reuse_variables()
for k in self.predict_tower:
logger.info("Building graph for predict tower 0{}...".format(k))
logger.info("Building graph for predict towerp{}...".format(k))
with tf.device('/gpu:{}'.format(k)), \
tf.name_scope('tower0{}'.format(k)):
tf.name_scope('towerp{}'.format(k)):
self.model.build_graph(inputs, False)
tf.get_variable_scope().reuse_variables()
def _single_tower_grad(self):
""" Get grad and cost for single-tower case"""
model_inputs = self._get_model_inputs()
self.dequed_inputs = model_inputs = self._get_model_inputs()
self.model.build_graph(model_inputs, True)
cost_var = self.model.get_cost()
grads = self.config.optimizer.compute_gradients(cost_var)
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
return grads
def _build_enque_thread(self):
# create a thread that keeps filling the queue
enqueue_op = self.input_queue.enqueue(self.input_vars)
self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
self.extra_threads_procs.append(self.input_th)
def train(self):
assert self.config.nr_tower == 1, "QueueInputTrainer only supports 1 tower!"
self.init_session_and_coord()
self._build_enque_thread()
grads = self._single_tower_grad()
grads = self.process_grads(grads)
describe_model()
with freeze_collection(self.SUMMARY_BACKUP_KEYS):
self._build_predict_tower()
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
summary_moving_average())
self.main_loop()
def run_step(self):
""" just run self.train_op"""
self.sess.run([self.train_op])
def _trigger_epoch(self):
# need to run summary_op every epoch
# note that summary_op will take a datapoint from the queue
if self.summary_op is not None:
summary_str = self.summary_op.eval()
self._process_summary(summary_str)
def get_predict_func(self, input_names, output_names, tower=0):
"""
:param tower: return the kth predict_func
:returns: a predictor function
"""
tower = self.predict_tower[tower % len(self.predict_tower)]
raw_input_vars = get_vars_by_names(input_names)
output_names = ['towerp{}/'.format(tower) + n for n in output_names]
output_vars = get_vars_by_names(output_names)
def func(inputs):
assert len(inputs) == len(raw_input_vars)
feed = dict(zip(raw_input_vars, inputs))
return self.sess.run(output_vars, feed_dict=feed)
return func
def get_predict_funcs(self, input_names, output_names, n):
""" return n predicts functions evenly on each predict_tower"""
return [self.get_predict_func(input_names, output_names, k)
for k in range(n)]
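# Usage sketch (hypothetical tensor names, not part of this commit): callers
# pass plain graph names; the tower prefix is added internally before lookup.
#   pred = trainer.get_predict_func(['input'], ['output/prob'])
#   prob, = pred([image_batch])   # actually fetches towerp0/output/prob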
class MultiGPUTrainer(QueueInputTrainer):
""" Base class for multi-gpu training"""
def __init__(self, config, input_queue=None, predict_tower=None):
super(MultiGPUTrainer, self).__init__(config, input_queue, predict_tower)
assert config.nr_tower > 1
self.dequed_inputs = []
@staticmethod
def _average_grads(tower_grads):
ret = []
for grad_and_vars in zip(*tower_grads):
v = grad_and_vars[0][1]
try:
grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
except AssertionError:
logger.error("Error while processing gradients of {}".format(v.name))
raise
ret.append((grad, v))
return ret
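# Worked sketch (not part of this commit): zip(*tower_grads) groups each
# variable's per-tower gradients, and the result is their arithmetic mean.
# With two towers and variables w, b:
#   tower_grads = [[(g0w, w), (g0b, b)],    # tower 0
#                  [(g1w, w), (g1b, b)]]    # tower 1
#   _average_grads(tower_grads)
#   # -> [((g0w + g1w) / 2.0, w), ((g0b + g1b) / 2.0, b)]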
def _multi_tower_grads(self):
logger.info("Training a model of {} tower".format(self.config.nr_tower))
# to avoid repeated summary from each device
collect_dedup = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
kept_summaries = {}
for k in collect_dedup:
del tf.get_collection_ref(k)[:]
grad_list = []
for i in range(self.config.nr_tower):
with tf.device('/gpu:{}'.format(i)), \
tf.name_scope('tower{}'.format(i)) as scope:
logger.info("Building graph for training tower {}...".format(i))
model_inputs = self._get_model_inputs() # each tower dequeue from input queue
self.dequed_inputs.append(model_inputs)
self.model.build_graph(model_inputs, True)
cost_var = self.model.get_cost() # build tower
@@ -196,23 +247,38 @@ class QueueInputTrainer(Trainer):
if i == 0:
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
tf.get_variable_scope().reuse_variables()
for k in collect_dedup:
kept_summaries[k] = copy.copy(tf.get_collection(k))
for k in collect_dedup:
del tf.get_collection_ref(k)[:]
tf.get_collection_ref(k).extend(kept_summaries[k])
# avoid repeated summary from each device
backup = backup_collection(self.SUMMARY_BACKUP_KEYS)
restore_collection(backup)
return grad_list
class SyncMultiGPUTrainer(MultiGPUTrainer):
def train(self):
enqueue_op = self.input_queue.enqueue(self.input_vars)
self.init_session_and_coord()
self._build_enque_thread()
self._build_predict_tower()
if self.config.nr_tower > 1:
grad_list = self._multi_tower_grads()
if not self.async:
grads = QueueInputTrainer._average_grads(grad_list)
grads = MultiGPUTrainer._average_grads(grad_list)
grads = self.process_grads(grads)
else:
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
summary_moving_average())
describe_model()
self._build_predict_tower()
# [debug]: do nothing in training
#self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
self.main_loop()
class AsyncMultiGPUTrainer(MultiGPUTrainer):
def train(self):
self.init_session_and_coord()
self._build_enque_thread()
grad_list = self._multi_tower_grads()
# pretend to average the grads, in order to make async and
# sync have consistent effective learning rate
def scale(grads):
@@ -220,15 +286,12 @@ class QueueInputTrainer(Trainer):
grad_list = map(scale, grad_list)
grad_list = [self.process_grads(g) for g in grad_list]
grads = grad_list[0] # use grad from the first tower for the main iteration
else:
grads = self._single_tower_grad()
grads = self.process_grads(grads)
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
summary_moving_average())
describe_model()
if self.async:
# prepare train_op for the rest of the towers
self.threads = []
for k in range(1, self.config.nr_tower):
@@ -240,17 +303,13 @@ class QueueInputTrainer(Trainer):
self.threads.append(th)
self.async_running = False
self._build_predict_tower()
self.init_session_and_coord()
# create a thread that keeps filling the queue
self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
self.extra_threads_procs.append(self.input_th)
# do nothing in training
# [debug]: do nothing in training
#self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
self.main_loop()
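# Hedged sketch (an assumption; the body of scale() is truncated in this
# hunk): each tower's gradients are presumably divided by nr_tower so the
# effective learning rate matches the synchronous average above:
#   def scale(grads):
#       return [(grad / self.config.nr_tower, var) for grad, var in grads]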
def run_step(self):
if self.async:
if not self.async_running:
self.async_running = True
for th in self.threads: # resume all threads
@@ -258,41 +317,9 @@ class QueueInputTrainer(Trainer):
self.sess.run([self.train_op]) # faster since train_op returns None
def _trigger_epoch(self):
# note that summary_op will take a datapoint from the queue
if self.async:
self.async_running = False
for th in self.threads:
th.pause()
if self.summary_op is not None:
summary_str = self.summary_op.eval()
self._process_summary(summary_str)
def get_predict_func(self, input_names, output_names, tower=0):
"""
:param tower: return the kth predict_func
"""
tower = self.predict_tower[tower % len(self.predict_tower)]
if self.config.nr_tower > 1:
logger.info("Prepare a predictor function for tower0{} ...".format(tower))
raw_input_vars = get_vars_by_names(input_names)
if self.config.nr_tower > 1:
output_names = ['tower0{}/'.format(tower) + n for n in output_names]
output_vars = get_vars_by_names(output_names)
def func(inputs):
assert len(inputs) == len(raw_input_vars)
feed = dict(zip(raw_input_vars, inputs))
return self.sess.run(output_vars, feed_dict=feed)
return func
def get_predict_funcs(self, input_names, output_names, n):
return [self.get_predict_func(input_names, output_names, k)
for k in range(n)]
def AsyncMultiGPUTrainer(config):
return QueueInputTrainer(config, async=True)
def SyncMultiGPUTrainer(config):
return QueueInputTrainer(config)