Commit c279dbfe authored by Yuxin Wu

use _setup and other small refactors

parent 82b418fd
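The core of the refactor: trainers no longer override train() directly; train() now calls setup() (which delegates graph construction to the new _setup() hook) and then main_loop(). A minimal sketch of the resulting template-method flow, with illustrative bodies (only the method names come from the diff below):

```python
from abc import ABCMeta, abstractmethod

class Trainer(object):
    """Illustrative flow after this refactor (not the full class)."""
    __metaclass__ = ABCMeta

    def train(self):
        self.setup()       # build the graph, then prepare callbacks/session
        self.main_loop()   # repeatedly call run_step() and trigger callbacks

    def setup(self):
        self._setup()      # trainer-specific graph construction (the new hook)
        # ... callback setup, summary writer, session / queue-runner initialization ...

    @abstractmethod
    def _setup(self):
        """Subclasses override this instead of train()."""
```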
@@ -12,7 +12,23 @@ from .symbolic_functions import rms
from .summary import add_moving_summary
__all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
'ScaleGradient', 'MapGradient']
'ScaleGradient', 'MapGradient', 'apply_grad_processors']
def apply_grad_processors(grads, gradprocs):
"""
:param grads: list of (grad, var).
:param gradprocs: list of `GradientProcessor` instances.
:returns: list of (grad, var) pairs after going through all the processors
"""
g = []
for grad, var in grads:
if grad is None:
logger.warn("No Gradient w.r.t {}".format(var.op.name))
else:
g.append((grad, var))
for proc in gradprocs:
g = proc.process(g)
return g
class GradientProcessor(object):
__metaclass__ = ABCMeta
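A minimal usage sketch of the new apply_grad_processors helper (the toy variable, cost, and optimizer below are placeholders, not part of the diff):

```python
import tensorflow as tf
from tensorpack.tfutils.gradproc import (
    apply_grad_processors, SummaryGradient, CheckGradient)

w = tf.Variable([1.0, 2.0], name='w')
cost = tf.reduce_sum(tf.square(w), name='cost')
opt = tf.train.GradientDescentOptimizer(0.1)

grads = opt.compute_gradients(cost)            # list of (grad, var) pairs
grads = apply_grad_processors(
    grads, [SummaryGradient(), CheckGradient()])
train_op = opt.apply_gradients(grads)
```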
@@ -98,12 +114,14 @@ class CheckGradient(MapGradient):
class ScaleGradient(MapGradient):
"""
Scale gradient by a multiplier
Scale certain gradients by a multiplier
"""
def __init__(self, multipliers):
"""
:param multipliers: a (regex, float) pair, or a list of such pairs
"""
if not isinstance(multipliers, list):
multipliers = [multipliers]
self.multipliers = multipliers
super(ScaleGradient, self).__init__(self._mapper)
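Continuing the sketch above, ScaleGradient now also accepts a single (regex, float) pair instead of a list; the regex here is illustrative:

```python
from tensorpack.tfutils.gradproc import ScaleGradient, apply_grad_processors

# Double the gradients of every variable whose name ends with '/b' (a toy regex).
grads = apply_grad_processors(grads, [ScaleGradient(('.*/b', 2.0))])
# Equivalent to the list form: ScaleGradient([('.*/b', 2.0)])
```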
@@ -98,11 +98,11 @@ def add_moving_summary(v, *args):
v = [v]
v.extend(args)
for x in v:
assert x.get_shape().ndims == 0
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x)
def summary_moving_average():
""" Create a MovingAverage op and summary for all variables in
MOVING_SUMMARY_VARS_KEY.
""" Create a MovingAverage op and summary for all variables in MOVING_SUMMARY_VARS_KEY.
:returns: an op to maintain these averages.
"""
with tf.name_scope('EMA_summary'):
@@ -113,7 +113,6 @@ def summary_moving_average():
vars_to_summary = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
avg_maintain_op = averager.apply(vars_to_summary)
for idx, c in enumerate(vars_to_summary):
# TODO assert scalar
name = re.sub('tower[p0-9]+/', '', c.op.name)
tf.scalar_summary(name, averager.average(c))
return avg_maintain_op
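A usage sketch of the add_moving_summary / summary_moving_average pair (TF 0.x-era API; the cost tensor is a toy placeholder):

```python
import tensorflow as tf
from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average

cost = tf.reduce_mean(tf.square(tf.random_normal([32])), name='train_cost')
add_moving_summary(cost)                 # must be scalar; registered under MOVING_SUMMARY_VARS_KEY
maintain_op = summary_moving_average()   # EMA update op + one scalar summary per registered tensor
# The trainers below group this with the optimizer step:
# train_op = tf.group(apply_gradients_op, maintain_op)
```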
@@ -45,10 +45,10 @@ class Trainer(object):
self.sess = tf.Session(config=self.config.session_config)
self.coord = tf.train.Coordinator()
@abstractmethod
def train(self):
""" Start training"""
pass
self.setup()
self.main_loop()
@abstractmethod
def run_step(self):
@@ -92,7 +92,8 @@ class Trainer(object):
create_summary(name, val), get_global_step())
self.stat_holder.add_stat(name, val)
def finalize(self):
def setup(self):
self._setup()
# some final operations that might modify the graph
logger.info("Setup callbacks ...")
self.config.callbacks.setup_graph(weakref.proxy(self))
@@ -112,8 +113,11 @@ class Trainer(object):
tf.train.start_queue_runners(
sess=self.sess, coord=self.coord, daemon=True, start=True)
@abstractmethod
def _setup(self):
""" setup Trainer-specific stuff for training"""
def main_loop(self):
self.finalize()
callbacks = self.config.callbacks
with self.sess.as_default():
try:
@@ -139,16 +143,3 @@ class Trainer(object):
self.coord.request_stop()
self.summary_writer.close()
self.sess.close()
def process_grads(self, grads):
g = []
for grad, var in grads:
if grad is None:
logger.warn("No Gradient w.r.t {}".format(var.op.name))
else:
g.append((grad, var))
procs = self.config.model.get_gradient_processor()
for proc in procs:
g = proc.process(g)
return g
@@ -14,6 +14,7 @@ from ..tfutils.summary import summary_moving_average
from ..tfutils.modelutils import describe_model
from ..tfutils import (backup_collection, restore_collection,
get_global_step_var, TowerContext)
from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
from .trainer import QueueInputTrainer
@@ -32,11 +33,16 @@ class MultiGPUTrainer(QueueInputTrainer):
with tf.name_scope('AvgGrad'):
for grad_and_vars in zip(*tower_grads):
v = grad_and_vars[0][1]
for x in grad_and_vars:
assert x[0] is not None, \
"Gradient w.r.t {} is None!".format(v.name)
all_grad = [k[0] for k in grad_and_vars]
nones = list(set(all_grad))
if None in nones and len(nones) != 1:
raise RuntimeError("Gradient w.r.t {} is None in some but not all towers!".format(v.name))
elif nones[0] is None:
logger.warn("No Gradient w.r.t {}".format(var.op.name))
continue
try:
grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
grad = tf.add_n(all_grad) / float(len(tower_grads))
except:
logger.error("Error while processing gradients of {}".format(v.name))
raise
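A toy illustration of the zip/add_n averaging above (the tensors and the 'v0' name are placeholders). The changed code additionally skips, with a warning, any variable whose gradient is None in every tower, and raises if only some towers return None:

```python
import tensorflow as tf

tower_grads = [
    [(tf.constant([1.0, 2.0]), 'v0')],   # tower 0: one (grad, var) pair
    [(tf.constant([3.0, 4.0]), 'v0')],   # tower 1: same variable, different gradient
]
for grad_and_vars in zip(*tower_grads):  # group gradients of the same variable
    all_grad = [g for g, _ in grad_and_vars]
    avg = tf.add_n(all_grad) / float(len(tower_grads))   # evaluates to [2.0, 3.0]
```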
@@ -44,8 +50,7 @@ class MultiGPUTrainer(QueueInputTrainer):
return ret
def _multi_tower_grads(self):
logger.info("Training a model of {} tower".format(
len(self.config.tower)))
logger.info("Training a model of {} tower".format(len(self.config.tower)))
grad_list = []
global_scope = tf.get_variable_scope()
@@ -60,59 +65,54 @@ class MultiGPUTrainer(QueueInputTrainer):
self.model.build_graph(model_inputs)
cost_var = self.model.get_cost() # build tower
# TODO gate_gradienst=0 seems to be faster?
# TODO gate_gradients=0 might be faster?
grad_list.append(
self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))
if idx == 0:
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
add_moving_summary(cost_var)
# avoid repeated summary from each device
backup = backup_collection(SUMMARY_BACKUP_KEYS)
restore_collection(backup)
return grad_list
class SyncMultiGPUTrainer(MultiGPUTrainer):
def train(self):
def _setup(self):
self._build_enque_thread()
grad_list = self._multi_tower_grads()
grads = MultiGPUTrainer._average_grads(grad_list)
grads = self.process_grads(grads)
grads = apply_grad_processors(grads,
self.model.get_gradient_processor())
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
summary_moving_average(), name='train_op')
describe_model()
# [debug]: do nothing in training
#self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
self.main_loop()
class AsyncMultiGPUTrainer(MultiGPUTrainer):
def train(self):
def _setup(self):
self._build_enque_thread()
grad_list = self._multi_tower_grads()
gradprocs = self.model.get_gradient_processor()
# pretend to average the grads, in order to make async and
# sync have consistent effective learning rate
def scale(grads):
with tf.name_scope('AsyncScaleGrad'):
return [(grad / len(self.config.tower) if grad is not None else None, var)
for grad, var in grads]
grad_list = map(scale, grad_list)
grad_list = [self.process_grads(g) for g in grad_list]
gradprocs.insert(0, ScaleGradient(('.*', 1.0 / self.config.nr_tower)))
grad_list = [apply_grad_processors(g, gradprocs) for g in grad_list]
# use grad from the first tower for iteration in main thread
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grad_list[0], get_global_step_var()),
self.config.optimizer.apply_gradients(
grad_list[0], get_global_step_var()),
summary_moving_average(), name='train_op')
describe_model()
self._start_async_threads(grad_list)
self.main_loop()
def _start_async_threads(self, grad_list):
# prepare train_op for the rest of the towers
# itertools.count is atomic w.r.t. python threads
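A back-of-the-envelope check of the "consistent effective learning rate" comment above: prepending ScaleGradient(('.*', 1/nr_tower)) makes the N asynchronous updates per round move a variable by the same total amount as one synchronous averaged step (toy numbers, plain Python):

```python
nr_tower = 4
tower_grads = [1.0, 2.0, 3.0, 4.0]                    # one gradient per tower, same variable
sync_update = sum(tower_grads) / nr_tower             # single averaged update: 2.5
async_updates = [g / nr_tower for g in tower_grads]   # applied one at a time
assert abs(sum(async_updates) - sync_update) < 1e-12  # same total movement per round
```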
@@ -145,7 +145,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
async_step_total_cnt = int(re.findall(
'[0-9]+', self.async_step_counter.__str__())[0])
self.write_scalar_summary(
'async_global_step', async_step_total_cnt)
'async-global-step', async_step_total_cnt)
except:
pass
logger.exception("Cannot log async-global-step")
super(AsyncMultiGPUTrainer, self)._trigger_epoch()
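The regex above recovers the running total from the counter's repr, since itertools.count exposes no public value accessor; a small demonstration of the trick:

```python
import itertools
import re

counter = itertools.count()
next(counter); next(counter); next(counter)
# repr(counter) is now 'count(3)', so the regex extracts the total
assert int(re.findall('[0-9]+', str(counter))[0]) == 3
```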
@@ -18,6 +18,7 @@ from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..tfutils.modelutils import describe_model
from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
from ..callbacks.concurrency import StartProcOrThread
from ..tfutils.gradproc import apply_grad_processors
__all__ = ['SimpleTrainer', 'QueueInputTrainer']
@@ -51,37 +52,39 @@ class PredictorFactory(object):
# build_predict_tower might get called anywhere, but 'towerp' should be the outermost name scope
with tf.name_scope(None), \
freeze_collection(SUMMARY_BACKUP_KEYS):
build_multi_tower_prediction_graph(
self.model, self.towers)
build_multi_tower_prediction_graph(self.model, self.towers)
self.tower_built = True
class SimpleTrainer(Trainer):
def __init__(self, config):
super(SimpleTrainer, self).__init__(config)
self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
def run_step(self):
data = next(self.data_producer)
feed = dict(zip(self.input_vars, data))
self.sess.run([self.train_op], feed_dict=feed) # faster since train_op returns None
def train(self):
def _setup(self):
model = self.model
self.input_vars = model.get_input_vars()
with TowerContext(''):
model.build_graph(self.input_vars)
cost_var = model.get_cost() # TODO assert scalar
cost_var = model.get_cost()
add_moving_summary(cost_var)
grads = self.config.optimizer.compute_gradients(cost_var)
grads = self.process_grads(grads)
grads = apply_grad_processors(grads,
self.model.get_gradient_processor())
avg_maintain_op = summary_moving_average()
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
avg_maintain_op)
summary_moving_average())
describe_model()
# create an infinite data producer
self.config.dataset.reset_state()
self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
self.main_loop()
def _trigger_epoch(self):
if self.summary_op is not None:
@@ -91,14 +94,14 @@ class SimpleTrainer(Trainer):
self._process_summary(summary_str)
def get_predict_func(self, input_names, output_names):
if not hasattr(self, 'predictor_factory'):
self.predictor_factory = PredictorFactory(self.sess, self.model, [0])
return self.predictor_factory.get_predictor(input_names, output_names, 0)
return self._predictor_factory.get_predictor(input_names, output_names, 0)
class EnqueueThread(threading.Thread):
def __init__(self, trainer):
super(EnqueueThread, self).__init__()
self.name = 'EnqueueThread'
self.daemon = True
self.sess = trainer.sess
self.coord = trainer.coord
self.dataflow = RepeatedData(trainer.config.dataset, -1)
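RepeatedData(dataset, -1) makes get_data() cycle forever, which is what the enqueue thread (and SimpleTrainer's data_producer) rely on; a small sketch with a toy DataFlow:

```python
from tensorpack.dataflow import DataFlow, RepeatedData

class ToyFlow(DataFlow):                 # minimal dataflow yielding three datapoints
    def get_data(self):
        for i in range(3):
            yield [i]
    def size(self):
        return 3

df = RepeatedData(ToyFlow(), -1)         # -1: repeat indefinitely
df.reset_state()
it = df.get_data()
print([next(it) for _ in range(7)])      # wraps around: [[0], [1], [2], [0], [1], [2], [0]]
```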
@@ -109,7 +112,8 @@ class EnqueueThread(threading.Thread):
self.close_op = self.queue.close(cancel_pending_enqueues=True)
self.size_op = self.queue.size()
self.daemon = True
add_moving_summary(tf.cast(
self.size_op, tf.float32, name='input_queue_size'))
def run(self):
self.dataflow.reset_state()
@@ -155,7 +159,9 @@ class QueueInputTrainer(Trainer):
self.input_queue = input_queue
# by default, use the first training gpu for prediction
self.predict_tower = predict_tower or [0]
predict_tower = predict_tower or [0]
self._predictor_factory = PredictorFactory(
self.sess, self.model, predict_tower)
self.dequed_inputs = None
def _get_dequeued_inputs(self):
@@ -171,8 +177,6 @@ class QueueInputTrainer(Trainer):
def _single_tower_grad(self):
""" Get grad and cost for single-tower"""
self.dequed_inputs = model_inputs = self._get_dequeued_inputs()
add_moving_summary(tf.cast(
self.input_queue.size(), tf.float32, name='input-queue-size'))
# test the overhead of queue
#with tf.device('/gpu:0'):
@@ -192,24 +196,22 @@ class QueueInputTrainer(Trainer):
self.input_th = EnqueueThread(self)
self.config.callbacks.append(StartProcOrThread(self.input_th))
def train(self):
def _setup(self):
assert len(self.config.tower) == 1, \
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
self._build_enque_thread()
grads = self._single_tower_grad()
grads = self.process_grads(grads)
grads = apply_grad_processors(grads,
self.model.get_gradient_processor())
describe_model()
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
summary_moving_average(), name='train_op')
# skip training
#self.train_op = tf.group(*self.dequed_inputs)
self.main_loop()
def run_step(self):
""" Simply run self.train_op"""
self.sess.run(self.train_op)
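For context, QueueInputTrainer builds on the queue-input pattern sketched below: placeholders are enqueued by the background EnqueueThread and the training tower consumes dequeued tensors instead of a feed_dict (shapes and names are illustrative):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 28, 28], name='input')
queue = tf.FIFOQueue(50, [tf.float32], name='input_queue')
enqueue_op = queue.enqueue([x])        # run repeatedly by EnqueueThread with a feed_dict
dequeued_x = queue.dequeue()           # the tower reads this; no feed_dict in run_step()
dequeued_x.set_shape([None, 28, 28])   # dequeue drops static shape info, restore it
```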
@@ -236,10 +238,7 @@ class QueueInputTrainer(Trainer):
:param tower: return the kth predict_func
:returns: an `OnlinePredictor`
"""
if not hasattr(self, 'predictor_factory'):
self.predictor_factory = PredictorFactory(
self.sess, self.model, self.predict_tower)
return self.predictor_factory.get_predictor(input_names, output_names, tower)
return self._predictor_factory.get_predictor(input_names, output_names, tower)
def get_predict_funcs(self, input_names, output_names, n):
return [self.get_predict_func(input_names, output_names, k) for k in range(n)]