Commit c279dbfe authored by Yuxin Wu

use _setup and small other refactors

parent 82b418fd
@@ -12,7 +12,23 @@ from .symbolic_functions import rms
 from .summary import add_moving_summary

 __all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
-           'ScaleGradient', 'MapGradient']
+           'ScaleGradient', 'MapGradient', 'apply_grad_processors']
+
+def apply_grad_processors(grads, gradprocs):
+    """
+    :param grads: list of (grad, var).
+    :param gradprocs: list of `GradientProcessor` instances.
+    :returns: list of (grad, var) went through the processors
+    """
+    g = []
+    for grad, var in grads:
+        if grad is None:
+            logger.warn("No Gradient w.r.t {}".format(var.op.name))
+        else:
+            g.append((grad, var))
+    for proc in gradprocs:
+        g = proc.process(g)
+    return g
+
 class GradientProcessor(object):
     __metaclass__ = ABCMeta
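The behaviour of the new helper can be illustrated without TensorFlow: gradients that are None are dropped with a warning, then each processor transforms the surviving (grad, var) pairs in order. Below is a standalone toy re-implementation of those steps (not a call into the module); the floats, strings and the DoubleGradient class are illustrative stand-ins for gradient tensors, variables and GradientProcessor subclasses:

    # Toy sketch of the apply_grad_processors pipeline; names here are made up.
    class DoubleGradient(object):
        """Stand-in for a GradientProcessor such as ScaleGradient."""
        def process(self, grads):
            return [(g * 2, v) for g, v in grads]

    grads = [(0.5, 'fc/W'), (None, 'fc/b'), (1.5, 'conv0/W')]
    kept = [(g, v) for g, v in grads if g is not None]   # None gradients are dropped
    for proc in [DoubleGradient()]:                      # processors run in order
        kept = proc.process(kept)
    print(kept)   # [(1.0, 'fc/W'), (3.0, 'conv0/W')]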
@@ -98,12 +114,14 @@ class CheckGradient(MapGradient):

 class ScaleGradient(MapGradient):
     """
-    Scale gradient by a multiplier
+    Scale certain gradient by a multiplier
     """
     def __init__(self, multipliers):
         """
         :param multipliers: list of (regex, float)
         """
+        if not isinstance(multipliers, list):
+            multipliers = [multipliers]
         self.multipliers = multipliers
         super(ScaleGradient, self).__init__(self._mapper)
...
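With the added isinstance check, a single (regex, multiplier) pair can be passed directly instead of a one-element list. A small sketch, assuming the top-level package is importable as tensorpack (the regex and multiplier values are made up):

    from tensorpack.tfutils.gradproc import ScaleGradient

    procs_old_style = [ScaleGradient([('conv0/.*', 0.1)])]   # explicit list, as before
    procs_new_style = [ScaleGradient(('conv0/.*', 0.1))]     # single pair, now accepted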
@@ -98,11 +98,11 @@ def add_moving_summary(v, *args):
         v = [v]
     v.extend(args)
     for x in v:
+        assert x.get_shape().ndims == 0
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x)

 def summary_moving_average():
-    """ Create a MovingAverage op and summary for all variables in
-    MOVING_SUMMARY_VARS_KEY.
+    """ Create a MovingAverage op and summary for all variables in MOVING_SUMMARY_VARS_KEY.

     :returns: a op to maintain these average.
     """
     with tf.name_scope('EMA_summary'):
@@ -113,7 +113,6 @@ def summary_moving_average():
     vars_to_summary = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
     avg_maintain_op = averager.apply(vars_to_summary)
     for idx, c in enumerate(vars_to_summary):
-        # TODO assert scalar
         name = re.sub('tower[p0-9]+/', '', c.op.name)
         tf.scalar_summary(name, averager.average(c))
     return avg_maintain_op
...
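The new assertion means add_moving_summary now accepts only scalar (0-d) tensors, so per-element values have to be reduced first. A short sketch, assuming the package is named tensorpack and the TF 1.x-era graph API used elsewhere in this commit:

    import tensorflow as tf
    from tensorpack.tfutils.summary import add_moving_summary

    per_example_loss = tf.constant([0.3, 0.7, 1.2])
    # add_moving_summary(per_example_loss)   # would now trip the ndims == 0 assert
    add_moving_summary(tf.reduce_mean(per_example_loss, name='mean_loss'))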
@@ -45,10 +45,10 @@ class Trainer(object):
         self.sess = tf.Session(config=self.config.session_config)
         self.coord = tf.train.Coordinator()

-    @abstractmethod
     def train(self):
         """ Start training"""
-        pass
+        self.setup()
+        self.main_loop()

     @abstractmethod
     def run_step(self):
@@ -92,7 +92,8 @@ class Trainer(object):
                     create_summary(name, val), get_global_step())
             self.stat_holder.add_stat(name, val)

-    def finalize(self):
+    def setup(self):
+        self._setup()
         # some final operations that might modify the graph
         logger.info("Setup callbacks ...")
         self.config.callbacks.setup_graph(weakref.proxy(self))
@@ -112,8 +113,11 @@ class Trainer(object):
         tf.train.start_queue_runners(
             sess=self.sess, coord=self.coord, daemon=True, start=True)

+    @abstractmethod
+    def _setup(self):
+        """ setup Trainer-specific stuff for training"""
+
     def main_loop(self):
-        self.finalize()
         callbacks = self.config.callbacks
         with self.sess.as_default():
             try:
@@ -139,16 +143,3 @@ class Trainer(object):
             self.coord.request_stop()
             self.summary_writer.close()
             self.sess.close()
-
-    def process_grads(self, grads):
-        g = []
-        for grad, var in grads:
-            if grad is None:
-                logger.warn("No Gradient w.r.t {}".format(var.op.name))
-            else:
-                g.append((grad, var))
-        procs = self.config.model.get_gradient_processor()
-        for proc in procs:
-            g = proc.process(g)
-        return g
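The refactor turns training into a template method: train() is now concrete and simply runs setup() followed by main_loop(), setup() delegates trainer-specific graph construction to the new abstract _setup() hook, and gradient post-processing moves out of the base class into apply_grad_processors. A minimal sketch of that control flow; this is not the real Trainer class, the method and class names below are made up and only mirror the diff:

    from abc import ABCMeta, abstractmethod

    class TrainerSketch(object):
        __metaclass__ = ABCMeta

        def train(self):
            self.setup()          # graph-level preparation, runs once
            self.main_loop()      # generic epoch/step loop

        def setup(self):
            self._setup()         # trainer-specific graph construction (the new hook)
            print("setup callbacks, summary writer, queue runners ...")

        @abstractmethod
        def _setup(self):
            """ setup Trainer-specific stuff for training"""

        @abstractmethod
        def run_step(self):
            """ run one optimization step"""

        def main_loop(self):
            for _ in range(3):    # stand-in for the real epoch/step loop
                self.run_step()

    class ToyTrainer(TrainerSketch):
        def _setup(self):
            self.step = 0         # stands in for building self.train_op
        def run_step(self):
            self.step += 1

    t = ToyTrainer()
    t.train()
    print(t.step)   # 3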
@@ -14,6 +14,7 @@ from ..tfutils.summary import summary_moving_average
 from ..tfutils.modelutils import describe_model
 from ..tfutils import (backup_collection, restore_collection,
         get_global_step_var, TowerContext)
+from ..tfutils.gradproc import apply_grad_processors

 from .trainer import QueueInputTrainer
@@ -32,11 +33,16 @@ class MultiGPUTrainer(QueueInputTrainer):
        with tf.name_scope('AvgGrad'):
            for grad_and_vars in zip(*tower_grads):
                v = grad_and_vars[0][1]
-               for x in grad_and_vars:
-                   assert x[0] is not None, \
-                       "Gradient w.r.t {} is None!".format(v.name)
+               all_grad = [k[0] for k in grad_and_vars]
+
+               nones = list(set(all_grad))
+               if None in nones and len(nones) != 1:
+                   raise RuntimeError("Gradient w.r.t {} is None in some but not all towers!".format(v.name))
+               elif nones[0] is None:
+                   logger.warn("No Gradient w.r.t {}".format(var.op.name))
+                   continue
                try:
-                   grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
+                   grad = tf.add_n(all_grad) / float(len(tower_grads))
                except:
                    logger.error("Error while processing gradients of {}".format(v.name))
                    raise
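The rewritten check distinguishes three cases per variable: the gradient is None in every tower (warn and skip), None in only some towers (raise, since that indicates an inconsistent graph), or present everywhere (sum and divide by the number of towers). A plain-Python version of that decision with made-up numbers; combine and the values below are illustrative only, floats replace gradient tensors and a string replaces the variable:

    def combine(per_tower_grads, var_name):
        uniq = list(set(per_tower_grads))
        if None in uniq and len(uniq) != 1:
            raise RuntimeError(
                "Gradient w.r.t {} is None in some but not all towers!".format(var_name))
        if uniq[0] is None:
            return None   # no tower produced this gradient: warned and skipped in the diff
        return sum(per_tower_grads) / float(len(per_tower_grads))

    print(combine([0.25, 0.75], 'conv0/W'))   # 0.5, the averaged gradient
    print(combine([None, None], 'fc/b'))      # None, the variable is skipped
    # combine([0.25, None], 'fc/W')           # would raise RuntimeError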
@@ -44,8 +50,7 @@ class MultiGPUTrainer(QueueInputTrainer):
        return ret

    def _multi_tower_grads(self):
-       logger.info("Training a model of {} tower".format(
-           len(self.config.tower)))
+       logger.info("Training a model of {} tower".format(len(self.config.tower)))

        grad_list = []
        global_scope = tf.get_variable_scope()
@@ -60,59 +65,54 @@ class MultiGPUTrainer(QueueInputTrainer):
                self.model.build_graph(model_inputs)
                cost_var = self.model.get_cost() # build tower

-               # TODO gate_gradienst=0 seems to be faster?
+               # TODO gate_gradienst=0 might be faster?
                grad_list.append(
                    self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))

                if idx == 0:
-                   tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
+                   add_moving_summary(cost_var)
                    # avoid repeated summary from each device
                    backup = backup_collection(SUMMARY_BACKUP_KEYS)
        restore_collection(backup)
        return grad_list

 class SyncMultiGPUTrainer(MultiGPUTrainer):
-   def train(self):
+   def _setup(self):
        self._build_enque_thread()

        grad_list = self._multi_tower_grads()
        grads = MultiGPUTrainer._average_grads(grad_list)
-       grads = self.process_grads(grads)
+       grads = apply_grad_processors(grads,
+               self.model.get_gradient_processor())

        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
            summary_moving_average(), name='train_op')
        describe_model()

        # [debug]: do nothing in training
        #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
-       self.main_loop()

 class AsyncMultiGPUTrainer(MultiGPUTrainer):
-   def train(self):
+   def _setup(self):
        self._build_enque_thread()

        grad_list = self._multi_tower_grads()
+       gradprocs = self.model.get_gradient_processor()
        # pretend to average the grads, in order to make async and
        # sync have consistent effective learning rate
-       def scale(grads):
-           with tf.name_scope('AsyncScaleGrad'):
-               return [(grad / len(self.config.tower) if grad is not None else None, var)
-                       for grad, var in grads]
-       grad_list = map(scale, grad_list)
-
-       grad_list = [self.process_grads(g) for g in grad_list]
+       gradprocs.insert(0, ScaleGradient(('.*', 1.0 / self.config.nr_tower)))
+       grad_list = [apply_grad_processors(g, gradprocs) for g in grad_list]

        # use grad from the first tower for iteration in main thread
        self.train_op = tf.group(
-           self.config.optimizer.apply_gradients(grad_list[0], get_global_step_var()),
+           self.config.optimizer.apply_gradients(
+               grad_list[0], get_global_step_var()),
            summary_moving_average(), name='train_op')
        describe_model()
        self._start_async_threads(grad_list)

-       self.main_loop()
-
    def _start_async_threads(self, grad_list):
        # prepare train_op for the rest of the towers
        # itertools.count is atomic w.r.t. python threads
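The ScaleGradient('.*', 1.0 / nr_tower) processor is what the "pretend to average" comment refers to: with N towers each applying its own update asynchronously, scaling every gradient by 1/N makes the combined effect of the N updates match one synchronous update on averaged gradients, so the effective learning rate stays the same. A quick arithmetic check with made-up numbers:

    lr = 0.1
    tower_grads = [0.8, 1.0, 1.2, 1.0]      # one variable, 4 towers
    nr_tower = len(tower_grads)

    sync_update = lr * (sum(tower_grads) / nr_tower)             # one averaged step
    async_updates = [lr * (g / nr_tower) for g in tower_grads]   # gradients scaled by 1/N
    assert abs(sync_update - sum(async_updates)) < 1e-12         # both move the variable by ~0.1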
@@ -145,7 +145,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
            async_step_total_cnt = int(re.findall(
                '[0-9]+', self.async_step_counter.__str__())[0])
            self.write_scalar_summary(
-               'async_global_step', async_step_total_cnt)
+               'async-global-step', async_step_total_cnt)
        except:
-           pass
+           logger.exception("Cannot log async-global-step")
        super(AsyncMultiGPUTrainer, self)._trigger_epoch()
@@ -18,6 +18,7 @@ from ..tfutils.summary import summary_moving_average, add_moving_summary
 from ..tfutils.modelutils import describe_model
 from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
 from ..callbacks.concurrency import StartProcOrThread
+from ..tfutils.gradproc import apply_grad_processors

 __all__ = ['SimpleTrainer', 'QueueInputTrainer']
@@ -51,37 +52,39 @@ class PredictorFactory(object):
        # build_predict_tower might get called anywhere, but 'towerp' should be the outermost name scope
        with tf.name_scope(None), \
                freeze_collection(SUMMARY_BACKUP_KEYS):
-           build_multi_tower_prediction_graph(
-               self.model, self.towers)
+           build_multi_tower_prediction_graph(self.model, self.towers)
        self.tower_built = True

 class SimpleTrainer(Trainer):
+   def __init__(self, config):
+       super(SimpleTrainer, self).__init__(config)
+       self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
+
    def run_step(self):
        data = next(self.data_producer)
        feed = dict(zip(self.input_vars, data))
        self.sess.run([self.train_op], feed_dict=feed)    # faster since train_op return None

-   def train(self):
+   def _setup(self):
        model = self.model
        self.input_vars = model.get_input_vars()
        with TowerContext(''):
            model.build_graph(self.input_vars)
-           cost_var = model.get_cost() # TODO assert scalar
+           cost_var = model.get_cost()
        add_moving_summary(cost_var)

        grads = self.config.optimizer.compute_gradients(cost_var)
-       grads = self.process_grads(grads)
+       grads = apply_grad_processors(grads,
+               self.model.get_gradient_processor())

-       avg_maintain_op = summary_moving_average()
        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
-           avg_maintain_op)
+           summary_moving_average())
        describe_model()

        # create an infinte data producer
        self.config.dataset.reset_state()
        self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
-       self.main_loop()

    def _trigger_epoch(self):
        if self.summary_op is not None:
@@ -91,14 +94,14 @@ class SimpleTrainer(Trainer):
            self._process_summary(summary_str)

    def get_predict_func(self, input_names, output_names):
-       if not hasattr(self, 'predictor_factory'):
-           self.predictor_factory = PredictorFactory(self.sess, self.model, [0])
-       return self.predictor_factory.get_predictor(input_names, output_names, 0)
+       return self._predictor_factory.get_predictor(input_names, output_names, 0)

 class EnqueueThread(threading.Thread):
    def __init__(self, trainer):
        super(EnqueueThread, self).__init__()
        self.name = 'EnqueueThread'
+       self.daemon = True

        self.sess = trainer.sess
        self.coord = trainer.coord
        self.dataflow = RepeatedData(trainer.config.dataset, -1)
@@ -109,7 +112,8 @@ class EnqueueThread(threading.Thread):
        self.close_op = self.queue.close(cancel_pending_enqueues=True)
        self.size_op = self.queue.size()
-       self.daemon = True
+       add_moving_summary(tf.cast(
+           self.size_op, tf.float32, name='input_queue_size'))

    def run(self):
        self.dataflow.reset_state()
@@ -155,7 +159,9 @@ class QueueInputTrainer(Trainer):
        self.input_queue = input_queue

        # by default, use the first training gpu for prediction
-       self.predict_tower = predict_tower or [0]
+       predict_tower = predict_tower or [0]
+       self._predictor_factory = PredictorFactory(
+           self.sess, self.model, predict_tower)
        self.dequed_inputs = None

    def _get_dequeued_inputs(self):
@@ -171,8 +177,6 @@ class QueueInputTrainer(Trainer):
    def _single_tower_grad(self):
        """ Get grad and cost for single-tower"""
        self.dequed_inputs = model_inputs = self._get_dequeued_inputs()
-       add_moving_summary(tf.cast(
-           self.input_queue.size(), tf.float32, name='input-queue-size'))

        # test the overhead of queue
        #with tf.device('/gpu:0'):
@@ -192,24 +196,22 @@ class QueueInputTrainer(Trainer):
        self.input_th = EnqueueThread(self)
        self.config.callbacks.append(StartProcOrThread(self.input_th))

-   def train(self):
+   def _setup(self):
        assert len(self.config.tower) == 1, \
                "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
        self._build_enque_thread()

        grads = self._single_tower_grad()
-       grads = self.process_grads(grads)
+       grads = apply_grad_processors(grads,
+               self.model.get_gradient_processor())
        describe_model()

        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
            summary_moving_average(), name='train_op')

        # skip training
        #self.train_op = tf.group(*self.dequed_inputs)
-       self.main_loop()

    def run_step(self):
        """ Simply run self.train_op"""
        self.sess.run(self.train_op)
@@ -236,10 +238,7 @@ class QueueInputTrainer(Trainer):
        :param tower: return the kth predict_func
        :returns: an `OnlinePredictor`
        """
-       if not hasattr(self, 'predictor_factory'):
-           self.predictor_factory = PredictorFactory(
-               self.sess, self.model, self.predict_tower)
-       return self.predictor_factory.get_predictor(input_names, output_names, tower)
+       return self._predictor_factory.get_predictor(input_names, output_names, tower)

    def get_predict_funcs(self, input_names, output_names, n):
        return [self.get_predict_func(input_names, output_names, k) for k in range(n)]