Commit 9952c6c6 authored by Yuxin Wu

add trainer with shared mainloop

parent de37446f
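This commit replaces the monolithic start_train() with a small class hierarchy sharing one main loop. A minimal sketch of the intended shape (the stub bodies and the simplified loop here are illustrative only; the real code is in the diff below):

    from abc import ABCMeta, abstractmethod

    class Trainer(object):
        __metaclass__ = ABCMeta            # Python 2 style abstract base class

        def __init__(self, config):
            self.config = config           # a TrainConfig

        @abstractmethod
        def train(self):
            """Build the graph, then call self.main_loop()."""

        @abstractmethod
        def run_step(self):
            """Run a single training iteration."""

        def main_loop(self):
            # the shared epoch/step schedule; subclasses only supply run_step()
            for epoch in range(1, self.config.max_epoch + 1):
                for step in range(self.config.step_per_epoch):
                    self.run_step()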
@@ -9,6 +9,7 @@ import copy
import argparse
import re
import tqdm
from abc import ABCMeta, abstractmethod
from .models import ModelDesc
from .dataflow.common import RepeatedData
@@ -63,14 +64,6 @@ class TrainConfig(object):
        self.nr_tower = int(kwargs.pop('nr_tower', 1))
        assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
def average_grads(tower_grads):
    ret = []
    for grad_and_vars in zip(*tower_grads):
        grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
        v = grad_and_vars[0][1]
        ret.append((grad, v))
    return ret

def summary_grads(grads):
    for grad, var in grads:
        if grad is not None:
@@ -95,17 +88,117 @@ def scale_grads(grads, multiplier):
        ret.append((grad, var))
    return ret
def start_train(config):
class Trainer(object):
    __metaclass__ = ABCMeta

    def __init__(self, config):
        """
    Start training with a config
    Args:
        config: a TrainConfig instance
        Config: a `TrainConfig` instance
        """
    model = config.model
        assert isinstance(config, TrainConfig), type(config)
        self.config = config
        tf.add_to_collection(MODEL_KEY, config.model)

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def run_step(self):
        pass
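    # Template-method pattern: SimpleTrainer and QueueInputTrainer below
    # implement train() (graph construction) and run_step() (one iteration);
    # main_loop() drives the epoch/step schedule, callbacks, and shutdown.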
    def main_loop(self):
        callbacks = self.config.callbacks
        with self.sess.as_default():
            try:
                logger.info("Start training with global_step={}".format(get_global_step()))
                callbacks.before_train()
                tf.get_default_graph().finalize()

                for epoch in xrange(1, self.config.max_epoch + 1):
                    with timed_operation(
                            'Epoch {}, global_step={}'.format(
                                epoch, get_global_step() + self.config.step_per_epoch)):
                        for step in tqdm.trange(
                                self.config.step_per_epoch,
                                leave=True, mininterval=0.5,
                                dynamic_ncols=True, ascii=True):
                            if self.coord.should_stop():
                                return
                            self.run_step()
                            callbacks.trigger_step()
                        # note that summary_op will take a data from the queue
                        callbacks.trigger_epoch()
            except (KeyboardInterrupt, Exception):
                raise
            finally:
                self.coord.request_stop()
                # Do I need to run queue.close?
                callbacks.after_train()
                self.sess.close()
    def init_session_and_coord(self):
        self.sess = tf.Session(config=self.config.session_config)
        self.config.session_init.init(self.sess)

        # start daemon threads for the queue runners under a coordinator
        self.coord = tf.train.Coordinator()
        tf.train.start_queue_runners(
            sess=self.sess, coord=self.coord, daemon=True, start=True)
class SimpleTrainer(Trainer):
    def run_step(self):
        try:
            data = next(self.data_producer)
        except StopIteration:
            # restart the dataflow when it is exhausted
            self.data_producer = self.config.dataset.get_data()
            data = next(self.data_producer)
        feed = dict(zip(self.input_vars, data))
        self.sess.run([self.train_op], feed_dict=feed)  # faster since train_op returns None

    def train(self):
        model = self.config.model
        input_vars = model.get_input_vars()
        self.input_vars = input_vars
        cost_var = model.get_cost(input_vars, is_training=True)
        avg_maintain_op = summary_moving_average(cost_var)

        grads = self.config.optimizer.compute_gradients(cost_var)
        check_grads(grads)
        grads = scale_grads(grads, model.get_lr_multiplier())
        summary_grads(grads)

        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
            avg_maintain_op)

        describe_model()
        self.init_session_and_coord()
        self.data_producer = self.config.dataset.get_data()
        self.main_loop()
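# SimpleTrainer pushes each datapoint through feed_dict; QueueInputTrainer
# below enqueues the input tensors into a TF queue instead, so sess.run()
# reads inputs from inside the graph -- which also enables multi-tower
# (multi-GPU) replication.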
class QueueInputTrainer(Trainer):
    """
    Trainer which builds a queue for input.
    Supports multi-GPU training.
    """

    @staticmethod
    def _average_grads(tower_grads):
        ret = []
        for grad_and_vars in zip(*tower_grads):
            # arithmetic mean of one variable's gradient across all towers
            grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
            v = grad_and_vars[0][1]
            ret.append((grad, v))
        return ret
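    # e.g. with two towers, tower_grads = [[(g0, v)], [(g1, v)]] (g0, g1
    # hypothetical gradient tensors for the same variable v); zip(*...)
    # groups the per-variable pairs and the result is [((g0 + g1) / 2.0, v)]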
    def train(self):
        model = self.config.model
        input_vars = model.get_input_vars()
        input_queue = model.get_input_queue()
    callbacks = config.callbacks
    tf.add_to_collection(MODEL_KEY, model)
        enqueue_op = input_queue.enqueue(input_vars)

        def get_model_inputs():
@@ -117,19 +210,19 @@ def start_train(config):
            return model_inputs

        # get gradients to update:
    if config.nr_tower > 1:
        logger.info("Training a model of {} tower".format(config.nr_tower))
        if self.config.nr_tower > 1:
            logger.info("Training a model of {} tower".format(self.config.nr_tower))
            # to avoid repeated summary from each device
            coll_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
            kept_summaries = {}
            grad_list = []
    for i in range(config.nr_tower):
            for i in range(self.config.nr_tower):
                with tf.device('/gpu:{}'.format(i)), \
                        tf.name_scope('tower{}'.format(i)) as scope:
                    model_inputs = get_model_inputs()
                    cost_var = model.get_cost(model_inputs, is_training=True)
                    grad_list.append(
                        self.config.optimizer.compute_gradients(cost_var))
                    if i == 0:
                        tf.get_variable_scope().reuse_variables()
@@ -138,60 +231,39 @@ def start_train(config):
                for k in coll_keys:
                    del tf.get_collection(k)[:]
                    tf.get_collection(k).extend(kept_summaries[k])
    grads = average_grads(grad_list)
            grads = QueueInputTrainer._average_grads(grad_list)
        else:
            model_inputs = get_model_inputs()
            cost_var = model.get_cost(model_inputs, is_training=True)
    grads = config.optimizer.compute_gradients(cost_var)
            grads = self.config.optimizer.compute_gradients(cost_var)

        avg_maintain_op = summary_moving_average(cost_var)  # TODO(multigpu) average the cost from each device?

        check_grads(grads)
        grads = scale_grads(grads, model.get_lr_multiplier())
        summary_grads(grads)

    train_op = tf.group(
        config.optimizer.apply_gradients(grads, get_global_step_var()),
        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
            avg_maintain_op)
        describe_model()

    sess = tf.Session(config=config.session_config)
    config.session_init.init(sess)

    # start training:
    coord = tf.train.Coordinator()
    tf.train.start_queue_runners(
        sess=sess, coord=coord, daemon=True, start=True)
        self.init_session_and_coord()

        # create a thread that keeps filling the queue
    input_th = EnqueueThread(sess, coord, enqueue_op, config.dataset, input_queue)
        input_th = EnqueueThread(self.sess, self.coord, enqueue_op,
                                 self.config.dataset, input_queue)
        input_th.start()
        self.main_loop()
    with sess.as_default():
        try:
            logger.info("Start training with global_step={}".format(get_global_step()))
            callbacks.before_train()
            tf.get_default_graph().finalize()

    def run_step(self):
        self.sess.run([self.train_op])  # faster since train_op returns None
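        # unlike SimpleTrainer.run_step, there is no feed_dict: the dequeue
        # op inside the graph supplies each step's inputs from the queue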
        for epoch in xrange(1, config.max_epoch):
            with timed_operation(
                    'Epoch {}, global_step={}'.format(
                        epoch, get_global_step() + config.step_per_epoch)):
                for step in tqdm.trange(
                        config.step_per_epoch,
                        leave=True, mininterval=0.5,
                        dynamic_ncols=True, ascii=True):
                    if coord.should_stop():
                        return
                    sess.run([train_op])  # faster since train_op return None
                    callbacks.trigger_step()
                # note that summary_op will take a data from the queue
                callbacks.trigger_epoch()
        except (KeyboardInterrupt, Exception):
            raise
        finally:
            coord.request_stop()
            # Do I need to run queue.close?
            callbacks.after_train()
            sess.close()
def start_train(config):
    #if config.model.get_input_queue() is not None:
        ## XXX get_input_queue is called twice
        #tr = QueueInputTrainer()
    #else:
        #tr = SimpleTrainer()
    #tr = SimpleTrainer(config)
    tr = QueueInputTrainer(config)
    tr.train()
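For context, a usage sketch of the new entry point. The TrainConfig fields follow the attributes referenced in this diff; MyModel, my_dataflow, and my_callbacks are hypothetical stand-ins, and exact TrainConfig keyword arguments may differ:

    # hedged sketch -- the names below are placeholders, not part of this commit
    config = TrainConfig(
        dataset=my_dataflow,                     # a DataFlow yielding training batches
        optimizer=tf.train.AdamOptimizer(1e-3),
        callbacks=my_callbacks,                  # the callbacks run by main_loop()
        model=MyModel(),                         # a ModelDesc subclass
        step_per_epoch=100,
        max_epoch=10)
    start_train(config)                          # dispatches to QueueInputTrainer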