Commit e072d909 authored by Yuxin Wu

[WIP] reorganize trainer. fix batch_norm

parent 335d6c28
@@ -8,14 +8,9 @@ import tensorflow as tf
 import argparse
 import os
-from tensorpack.train import TrainConfig, QueueInputTrainer
-from tensorpack.models import *
-from tensorpack.callbacks import *
-from tensorpack.utils import *
-from tensorpack.tfutils import *
+from tensorpack import *
 from tensorpack.tfutils.symbolic_functions import *
 from tensorpack.tfutils.summary import *
-from tensorpack.dataflow import *

 """
 CIFAR10-resnet example.
@@ -186,11 +181,9 @@ if __name__ == '__main__':
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
-    with tf.Graph().as_default():
-        with tf.device('/cpu:0'):
-            config = get_config()
-            if args.load:
-                config.session_init = SaverRestore(args.load)
-            if args.gpu:
-                config.nr_tower = len(args.gpu.split(','))
-            QueueInputTrainer(config).train()
+    config = get_config()
+    if args.load:
+        config.session_init = SaverRestore(args.load)
+    if args.gpu:
+        config.nr_tower = len(args.gpu.split(','))
+    SyncMultiGPUTrainer(config).train()
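Note on the new launch logic above: config.nr_tower is simply the number of comma-separated device ids passed via --gpu, and SyncMultiGPUTrainer then builds one training tower per id. A minimal sketch (the flag value is made up for illustration):

    # Illustrative only: how nr_tower is derived from the --gpu flag,
    # e.g. `python cifar10-resnet.py --gpu 0,1,2,3`
    gpu_flag = '0,1,2,3'
    nr_tower = len(gpu_flag.split(','))   # -> 4 towers, one per listed GPU
    print(nr_tower)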
@@ -5,7 +5,9 @@
 import tensorflow as tf
 from copy import copy
+import re
+from ..utils import logger
 from ._common import layer_register

 __all__ = ['BatchNorm']
@@ -48,9 +50,28 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
     else:
         batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], keep_dims=False)

-    ema = tf.train.ExponentialMovingAverage(decay=decay)
-    ema_apply_op = ema.apply([batch_mean, batch_var])
-    ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+    emaname = 'EMA'
+    if not batch_mean.name.startswith('towerp'):
+        ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
+        ema_apply_op = ema.apply([batch_mean, batch_var])
+        ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+    else:
+        assert not use_local_stat
+        # have to do this again to get actual name. see issue:
+        # https://github.com/tensorflow/tensorflow/issues/2740
+        ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
+        ema_apply_op = ema.apply([batch_mean, batch_var])
+        ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+        mean_name = re.sub('towerp[0-9]+/', '', ema_mean.name)
+        var_name = re.sub('towerp[0-9]+/', '', ema_var.name)
+        #var_name = batch_var.op.name[prefixlen:] + '/' + emaname + ':0'
+        #logger.info("In prediction, using {} instead of {} for {}".format(
+        #    mean_name, ema_mean.name, batch_mean.name))
+        G = tf.get_default_graph()
+        ema_mean = G.get_tensor_by_name(mean_name)
+        ema_var = G.get_tensor_by_name(var_name)

     if use_local_stat:
         with tf.control_dependencies([ema_apply_op]):
@@ -58,6 +79,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
             x, batch_mean, batch_var, beta, gamma, epsilon, 'bn')
     else:
         batch = tf.cast(tf.shape(x)[0], tf.float32)
+        # XXX TODO batch==1?
         mean, var = ema_mean, ema_var * batch / (batch - 1)  # unbiased variance estimator
         return tf.nn.batch_normalization(
             x, mean, var, beta, gamma, epsilon, 'bn')
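The new prediction-tower branch above reuses the training tower's EMA statistics by stripping the towerp<k>/ prefix from the tensor name and looking the result up in the default graph. A minimal sketch of that rewrite, using a made-up tensor name (real names depend on the model and on how the EMA op names its variables):

    import re

    # Hypothetical name of an EMA tensor as seen from inside a prediction tower
    # built under the 'towerp0/' name scope.
    name_in_predict_tower = 'towerp0/conv1/bn/mean/EMA:0'

    # Dropping the 'towerp<k>/' prefix recovers the name of the EMA tensor the
    # training tower created, so it can be fetched instead of a fresh, untrained copy.
    shared_name = re.sub('towerp[0-9]+/', '', name_in_predict_tower)
    print(shared_name)   # conv1/bn/mean/EMA:0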
@@ -5,13 +5,19 @@
 from ..utils.naming import *
 import tensorflow as tf
+from copy import copy
+import six
+from contextlib import contextmanager

 __all__ = ['get_default_sess_config',
            'get_global_step',
            'get_global_step_var',
            'get_op_var_name',
-           'get_vars_by_names'
-           ]
+           'get_vars_by_names',
+           'backup_collection',
+           'restore_collection',
+           'clear_collection',
+           'freeze_collection']

 def get_default_sess_config(mem_fraction=0.9):
     """
@@ -66,3 +72,24 @@ def get_vars_by_names(names):
         opn, varn = get_op_var_name(n)
         ret.append(G.get_tensor_by_name(varn))
     return ret
+
+def backup_collection(keys):
+    ret = {}
+    for k in keys:
+        ret[k] = copy(tf.get_collection(k))
+    return ret
+
+def restore_collection(backup):
+    for k, v in six.iteritems(backup):
+        del tf.get_collection_ref(k)[:]
+        tf.get_collection_ref(k).extend(v)
+
+def clear_collection(keys):
+    for k in keys:
+        del tf.get_collection_ref(k)[:]
+
+@contextmanager
+def freeze_collection(keys):
+    backup = backup_collection(keys)
+    yield
+    restore_collection(backup)
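A small usage sketch for the new collection helpers (the collection key is hypothetical): anything added to a frozen collection inside the with block is discarded on exit, because the collection is restored from the backup taken on entry. The trainer changes below use exactly this to keep prediction-tower summaries out of the training summaries.

    import tensorflow as tf

    key = 'my_collection'                  # hypothetical key, for illustration only
    tf.add_to_collection(key, 'kept')

    with freeze_collection([key]):
        tf.add_to_collection(key, 'discarded')   # visible only inside the block
        assert tf.get_collection(key) == ['kept', 'discarded']

    assert tf.get_collection(key) == ['kept']    # restored when the block exits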
@@ -16,7 +16,6 @@ from ..utils.concurrency import start_proc_mask_signal
 from ..callbacks import StatHolder
 from ..tfutils import *
 from ..tfutils.summary import create_summary
-from ..tfutils.modelutils import describe_model

 __all__ = ['Trainer']
@@ -141,7 +140,6 @@ class Trainer(object):
         self.sess.close()

     def init_session_and_coord(self):
-        describe_model()
         self.sess = tf.Session(config=self.config.session_config)
         self.coord = tf.train.Coordinator()
@@ -5,7 +5,6 @@
 import tensorflow as tf
 import threading
 import time
-import copy
 import re
 import functools
 from six.moves import zip
@@ -15,6 +14,7 @@ from ..dataflow.common import RepeatedData
 from ..utils import *
 from ..utils.concurrency import LoopThread
 from ..tfutils.summary import summary_moving_average
+from ..tfutils.modelutils import describe_model
 from ..tfutils import *

 __all__ = ['SimpleTrainer', 'QueueInputTrainer',
@@ -42,6 +42,7 @@ class SimpleTrainer(Trainer):
                 avg_maintain_op)
         self.init_session_and_coord()
+        describe_model()
         # create an infinte data producer
         self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
         self.main_loop()
@@ -100,14 +101,11 @@ class EnqueueThread(threading.Thread):
         logger.info("Enqueue Thread Exited.")

 class QueueInputTrainer(Trainer):
-    """
-    Trainer which builds a FIFO queue for input.
-    Support multi GPU.
-    """
+    """ Single GPU Trainer, takes input from a queue"""

-    def __init__(self, config, input_queue=None,
-            async=False,
-            predict_tower=None):
+    SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
+
+    def __init__(self, config, input_queue=None, predict_tower=None):
         """
         :param config: a `TrainConfig` instance
         :param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
@@ -120,27 +118,11 @@ class QueueInputTrainer(Trainer):
                 100, [x.dtype for x in self.input_vars], name='input_queue')
         else:
             self.input_queue = input_queue
-        self.async = async
-        if self.async:
-            assert self.config.nr_tower > 1
-        self.dequed_inputs = []

         if predict_tower is None:
-            # by default, only use first training tower for prediction
+            # by default, use first training tower for prediction
             predict_tower = [0]
         self.predict_tower = predict_tower
+        self.dequed_inputs = None
-
-    @staticmethod
-    def _average_grads(tower_grads):
-        ret = []
-        for grad_and_vars in zip(*tower_grads):
-            v = grad_and_vars[0][1]
-            try:
-                grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
-            except AssertionError:
-                logger.error("Error while processing gradients of {}".format(v.name))
-                raise
-            ret.append((grad, v))
-        return ret

     def _get_model_inputs(self):
         """ Dequeue a datapoint from input_queue and return"""
@@ -150,42 +132,111 @@ class QueueInputTrainer(Trainer):
         assert len(ret) == len(self.input_vars)
         for qv, v in zip(ret, self.input_vars):
             qv.set_shape(v.get_shape())
-        self.dequed_inputs.append(ret)
         return ret

     def _build_predict_tower(self):
         inputs = self.model.get_input_vars()
+        tf.get_variable_scope().reuse_variables()
         for k in self.predict_tower:
-            logger.info("Building graph for predict tower 0{}...".format(k))
+            logger.info("Building graph for predict towerp{}...".format(k))
             with tf.device('/gpu:{}'.format(k)), \
-                    tf.name_scope('tower0{}'.format(k)):
+                    tf.name_scope('towerp{}'.format(k)):
                 self.model.build_graph(inputs, False)
-                tf.get_variable_scope().reuse_variables()

     def _single_tower_grad(self):
         """ Get grad and cost for single-tower case"""
-        model_inputs = self._get_model_inputs()
+        self.dequed_inputs = model_inputs = self._get_model_inputs()
         self.model.build_graph(model_inputs, True)
         cost_var = self.model.get_cost()
         grads = self.config.optimizer.compute_gradients(cost_var)
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
         return grads
+    def _build_enque_thread(self):
+        # create a thread that keeps filling the queue
+        enqueue_op = self.input_queue.enqueue(self.input_vars)
+        self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
+        self.extra_threads_procs.append(self.input_th)
+
+    def train(self):
+        assert self.config.nr_tower == 1, "QueueInputTrainer only supports 1 tower!"
+        self.init_session_and_coord()
+        self._build_enque_thread()
+
+        grads = self._single_tower_grad()
+        grads = self.process_grads(grads)
+        describe_model()
+
+        with freeze_collection(self.SUMMARY_BACKUP_KEYS):
+            self._build_predict_tower()
+
+        self.train_op = tf.group(
+            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
+            summary_moving_average())
+
+        self.main_loop()
+
+    def run_step(self):
+        """ just run self.train_op"""
+        self.sess.run([self.train_op])
+
+    def _trigger_epoch(self):
+        # need to run summary_op every epoch
+        # note that summary_op will take a data from the queue
+        if self.summary_op is not None:
+            summary_str = self.summary_op.eval()
+            self._process_summary(summary_str)
+
+    def get_predict_func(self, input_names, output_names, tower=0):
+        """
+        :param tower: return the kth predict_func
+        :returns: a predictor function
+        """
+        tower = self.predict_tower[tower % len(self.predict_tower)]
+        raw_input_vars = get_vars_by_names(input_names)
+        output_names = ['towerp{}/'.format(tower) + n for n in output_names]
+        output_vars = get_vars_by_names(output_names)
+
+        def func(inputs):
+            assert len(inputs) == len(raw_input_vars)
+            feed = dict(zip(raw_input_vars, inputs))
+            return self.sess.run(output_vars, feed_dict=feed)
+        return func
+
+    def get_predict_funcs(self, input_names, output_names, n):
+        """ return n predicts functions evenly on each predict_tower"""
+        return [self.get_predict_func(input_names, output_names, k)
+                for k in range(n)]
+
+class MultiGPUTrainer(QueueInputTrainer):
+    """ Base class for multi-gpu training"""
+    def __init__(self, config, input_queue=None, predict_tower=None):
+        super(MultiGPUTrainer, self).__init__(config, input_queue, predict_tower)
+        assert config.nr_tower > 1
+        self.dequed_inputs = []
+
+    @staticmethod
+    def _average_grads(tower_grads):
+        ret = []
+        for grad_and_vars in zip(*tower_grads):
+            v = grad_and_vars[0][1]
+            try:
+                grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
+            except AssertionError:
+                logger.error("Error while processing gradients of {}".format(v.name))
+                raise
+            ret.append((grad, v))
+        return ret
     def _multi_tower_grads(self):
         logger.info("Training a model of {} tower".format(self.config.nr_tower))
-
-        # to avoid repeated summary from each device
-        collect_dedup = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
-        kept_summaries = {}
-        for k in collect_dedup:
-            del tf.get_collection_ref(k)[:]

         grad_list = []
         for i in range(self.config.nr_tower):
             with tf.device('/gpu:{}'.format(i)), \
                     tf.name_scope('tower{}'.format(i)) as scope:
                 logger.info("Building graph for training tower {}...".format(i))
                 model_inputs = self._get_model_inputs()    # each tower dequeue from input queue
+                self.dequed_inputs.append(model_inputs)
                 self.model.build_graph(model_inputs, True)
                 cost_var = self.model.get_cost()   # build tower
@@ -196,103 +247,79 @@ class QueueInputTrainer(Trainer):
                 if i == 0:
                     tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
                 tf.get_variable_scope().reuse_variables()
-                for k in collect_dedup:
-                    kept_summaries[k] = copy.copy(tf.get_collection(k))
-        for k in collect_dedup:
-            del tf.get_collection_ref(k)[:]
-            tf.get_collection_ref(k).extend(kept_summaries[k])
+                # avoid repeated summary from each device
+                backup = backup_collection(self.SUMMARY_BACKUP_KEYS)
+        restore_collection(backup)
         return grad_list
+class SyncMultiGPUTrainer(MultiGPUTrainer):
     def train(self):
-        enqueue_op = self.input_queue.enqueue(self.input_vars)
-
-        self._build_predict_tower()
-        if self.config.nr_tower > 1:
-            grad_list = self._multi_tower_grads()
-            if not self.async:
-                grads = QueueInputTrainer._average_grads(grad_list)
-                grads = self.process_grads(grads)
-            else:
-                # pretend to average the grads, in order to make async and
-                # sync have consistent effective learning rate
-                def scale(grads):
-                    return [(grad / self.config.nr_tower, var) for grad, var in grads]
-                grad_list = map(scale, grad_list)
-                grad_list = [self.process_grads(g) for g in grad_list]
-                grads = grad_list[0]  # use grad from the first tower for the main iteration
-        else:
-            grads = self._single_tower_grad()
-            grads = self.process_grads(grads)
+        self.init_session_and_coord()
+        self._build_enque_thread()
+
+        grad_list = self._multi_tower_grads()
+
+        grads = MultiGPUTrainer._average_grads(grad_list)
+        grads = self.process_grads(grads)

         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average())
-
-        if self.async:
-            # prepare train_op for the rest of the towers
-            self.threads = []
-            for k in range(1, self.config.nr_tower):
-                train_op = self.config.optimizer.apply_gradients(grad_list[k])
-                f = lambda op=train_op: self.sess.run([op])  # avoid late-binding
-                th = LoopThread(f)
-                th.pause()
-                th.start()
-                self.threads.append(th)
-            self.async_running = False
+        describe_model()
+
+        self._build_predict_tower()
+        # [debug]: do nothing in training
+        #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
+        self.main_loop()
+
+class AsyncMultiGPUTrainer(MultiGPUTrainer):
+    def train(self):
         self.init_session_and_coord()
-        # create a thread that keeps filling the queue
-        self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
-        self.extra_threads_procs.append(self.input_th)
-        # do nothing in training
+        self._build_enque_thread()
+
+        grad_list = self._multi_tower_grads()
+        # pretend to average the grads, in order to make async and
+        # sync have consistent effective learning rate
+        def scale(grads):
+            return [(grad / self.config.nr_tower, var) for grad, var in grads]
+        grad_list = map(scale, grad_list)
+        grad_list = [self.process_grads(g) for g in grad_list]
+        grads = grad_list[0]  # use grad from the first tower for the main iteration
+
+        self.train_op = tf.group(
+            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
+            summary_moving_average())
+        describe_model()
+
+        # prepare train_op for the rest of the towers
+        self.threads = []
+        for k in range(1, self.config.nr_tower):
+            train_op = self.config.optimizer.apply_gradients(grad_list[k])
+            f = lambda op=train_op: self.sess.run([op])  # avoid late-binding
+            th = LoopThread(f)
+            th.pause()
+            th.start()
+            self.threads.append(th)
+        self.async_running = False
+
+        self._build_predict_tower()
+        # [debug]: do nothing in training
         #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
         self.main_loop()
     def run_step(self):
-        if self.async:
-            if not self.async_running:
-                self.async_running = True
-                for th in self.threads:  # resume all threads
-                    th.resume()
+        if not self.async_running:
+            self.async_running = True
+            for th in self.threads:  # resume all threads
+                th.resume()
         self.sess.run([self.train_op])  # faster since train_op return None

     def _trigger_epoch(self):
-        # note that summary_op will take a data from the queue
-        if self.async:
-            self.async_running = False
-            for th in self.threads:
-                th.pause()
+        self.async_running = False
+        for th in self.threads:
+            th.pause()
         if self.summary_op is not None:
             summary_str = self.summary_op.eval()
             self._process_summary(summary_str)
-    def get_predict_func(self, input_names, output_names, tower=0):
-        """
-        :param tower: return the kth predict_func
-        """
-        tower = self.predict_tower[tower % len(self.predict_tower)]
-        if self.config.nr_tower > 1:
-            logger.info("Prepare a predictor function for tower0{} ...".format(tower))
-        raw_input_vars = get_vars_by_names(input_names)
-        if self.config.nr_tower > 1:
-            output_names = ['tower0{}/'.format(tower) + n for n in output_names]
-        output_vars = get_vars_by_names(output_names)
-
-        def func(inputs):
-            assert len(inputs) == len(raw_input_vars)
-            feed = dict(zip(raw_input_vars, inputs))
-            return self.sess.run(output_vars, feed_dict=feed)
-        return func
-
-    def get_predict_funcs(self, input_names, output_names, n):
-        return [self.get_predict_func(input_names, output_names, k)
-                for k in range(n)]
-
-def AsyncMultiGPUTrainer(config):
-    return QueueInputTrainer(config, async=True)
-
-def SyncMultiGPUTrainer(config):
-    return QueueInputTrainer(config)
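To make the grad-scaling comment in AsyncMultiGPUTrainer.train concrete: the sync trainer applies the average of the per-tower gradients once per step, while the async trainer lets each tower apply its own gradient independently, so dividing every tower's gradient by nr_tower keeps the accumulated update per round of towers comparable to the synchronous case (for plain gradient descent). A tiny numeric sketch with made-up values:

    nr_tower = 4
    tower_grads = [1.0, 0.8, 1.2, 1.0]    # hypothetical gradients of one variable, one per tower

    # Synchronous: a single update with the averaged gradient.
    sync_update = sum(tower_grads) / nr_tower             # 1.0

    # Asynchronous: each tower applies its own scaled gradient; after every tower
    # has applied once, the total matches the synchronous update.
    async_updates = [g / nr_tower for g in tower_grads]
    print(sync_update, sum(async_updates))                # 1.0 1.0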