Commit 9952c6c6 authored by Yuxin Wu

add trainer with shared mainloop

parent de37446f
@@ -9,6 +9,7 @@ import copy
 import argparse
 import re
 import tqdm
+from abc import ABCMeta, abstractmethod
 from .models import ModelDesc
 from .dataflow.common import RepeatedData
@@ -63,14 +64,6 @@ class TrainConfig(object):
         self.nr_tower = int(kwargs.pop('nr_tower', 1))
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
 
-def average_grads(tower_grads):
-    ret = []
-    for grad_and_vars in zip(*tower_grads):
-        grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
-        v = grad_and_vars[0][1]
-        ret.append((grad, v))
-    return ret
-
 def summary_grads(grads):
     for grad, var in grads:
         if grad:
@@ -95,17 +88,117 @@ def scale_grads(grads, multiplier):
         ret.append((grad, var))
     return ret
 
-def start_train(config):
-    """
-    Start training with a config
-    Args:
-        config: a TrainConfig instance
-    """
-    model = config.model
+class Trainer(object):
+    __metaclass__ = ABCMeta
+
+    def __init__(self, config):
+        """
+        Config: a `TrainConfig` instance
+        """
+        assert isinstance(config, TrainConfig), type(config)
+        self.config = config
+        tf.add_to_collection(MODEL_KEY, config.model)
+
+    @abstractmethod
+    def train(self):
+        pass
+
+    @abstractmethod
+    def run_step(self):
+        pass
+
+    def main_loop(self):
+        callbacks = self.config.callbacks
+        with self.sess.as_default():
+            try:
+                logger.info("Start training with global_step={}".format(get_global_step()))
+                callbacks.before_train()
+                tf.get_default_graph().finalize()
+                for epoch in xrange(1, self.config.max_epoch):
+                    with timed_operation(
+                            'Epoch {}, global_step={}'.format(
+                                epoch, get_global_step() + self.config.step_per_epoch)):
+                        for step in tqdm.trange(
+                                self.config.step_per_epoch,
+                                leave=True, mininterval=0.5,
+                                dynamic_ncols=True, ascii=True):
+                            if self.coord.should_stop():
+                                return
+                            self.run_step()
+                            callbacks.trigger_step()
+                        # note that summary_op will take data from the queue
+                        callbacks.trigger_epoch()
+            except (KeyboardInterrupt, Exception):
+                raise
+            finally:
+                self.coord.request_stop()
+                # Do I need to run queue.close?
+                callbacks.after_train()
+                self.sess.close()
+
+    def init_session_and_coord(self):
+        self.sess = tf.Session(config=self.config.session_config)
+        self.config.session_init.init(self.sess)
+
+        # start training:
+        self.coord = tf.train.Coordinator()
+        tf.train.start_queue_runners(
+            sess=self.sess, coord=self.coord, daemon=True, start=True)
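The new `Trainer` base class is a template method: `main_loop` owns the epoch/step iteration, callback scheduling, and shutdown, while each subclass supplies only `train` (graph construction) and `run_step` (one optimization step). Below is a minimal self-contained sketch of the same pattern with a toy TF 1.x-style graph in place of the real `TrainConfig` machinery; `ToyTrainer` is illustrative only, not part of this commit:

```python
import tensorflow as tf

class ToyTrainer(object):
    """Illustrates the Trainer contract: train() builds the graph and
    enters the shared loop; run_step() performs one optimization step."""

    def train(self):
        # fit y = 2x with a single weight, standing in for model.get_cost(...)
        x = tf.constant([[1.0], [2.0], [3.0]])
        y = tf.constant([[2.0], [4.0], [6.0]])
        w = tf.Variable(tf.zeros([1, 1]))
        cost = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
        self.train_op = tf.train.GradientDescentOptimizer(0.1).minimize(cost)
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
        self.main_loop()

    def run_step(self):
        self.sess.run([self.train_op])

    def main_loop(self):
        # stands in for Trainer.main_loop: fixed step count, no callbacks
        for _ in range(100):
            self.run_step()

ToyTrainer().train()
```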
+class SimpleTrainer(Trainer):
+    def run_step(self):
+        try:
+            data = next(self.data_producer)
+        except StopIteration:
+            self.data_producer = self.config.dataset.get_data()
+            data = next(self.data_producer)
+        feed = dict(zip(self.input_vars, data))
+        self.sess.run([self.train_op], feed_dict=feed)  # faster since train_op returns None
+
+    def train(self):
+        model = self.config.model
+        input_vars = model.get_input_vars()
+        self.input_vars = input_vars
+        cost_var = model.get_cost(input_vars, is_training=True)
+        avg_maintain_op = summary_moving_average(cost_var)
+
+        grads = self.config.optimizer.compute_gradients(cost_var)
+        check_grads(grads)
+        grads = scale_grads(grads, model.get_lr_multiplier())
+        summary_grads(grads)
+
+        self.train_op = tf.group(
+            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
+            avg_maintain_op)
+        describe_model()
+
+        self.init_session_and_coord()
+        self.data_producer = self.config.dataset.get_data()
+        self.main_loop()
+
+class QueueInputTrainer(Trainer):
+    """
+    Trainer which builds a queue for input. Supports multi-GPU.
+    """
+
+    @staticmethod
+    def _average_grads(tower_grads):
+        ret = []
+        for grad_and_vars in zip(*tower_grads):
+            grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
+            v = grad_and_vars[0][1]
+            ret.append((grad, v))
+        return ret
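`_average_grads` zips the per-tower gradient lists so that each tuple groups one variable's gradients from every tower, sums them with `tf.add_n`, divides by the tower count, and keeps the variable object from the first tower. A quick standalone check of the arithmetic with two towers and two variables (TF 1.x-style session, with the class above in scope):

```python
import tensorflow as tf

v0 = tf.Variable([0.0])
v1 = tf.Variable([0.0])
tower_grads = [
    [(tf.constant([2.0]), v0), (tf.constant([4.0]), v1)],  # tower 0
    [(tf.constant([6.0]), v0), (tf.constant([8.0]), v1)],  # tower 1
]
averaged = QueueInputTrainer._average_grads(tower_grads)
with tf.Session() as sess:
    # (2+6)/2 = 4 for v0, (4+8)/2 = 6 for v1
    print(sess.run([g for g, _ in averaged]))
```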
+    def train(self):
+        model = self.config.model
         input_vars = model.get_input_vars()
         input_queue = model.get_input_queue()
-    callbacks = config.callbacks
-    tf.add_to_collection(MODEL_KEY, model)
         enqueue_op = input_queue.enqueue(input_vars)
 
         def get_model_inputs():
@@ -117,19 +210,19 @@ def start_train(config):
             return model_inputs
 
         # get gradients to update:
-    if config.nr_tower > 1:
-        logger.info("Training a model of {} tower".format(config.nr_tower))
+        if self.config.nr_tower > 1:
+            logger.info("Training a model of {} tower".format(self.config.nr_tower))
             # to avoid repeated summary from each device
             coll_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
             kept_summaries = {}
             grad_list = []
-        for i in range(config.nr_tower):
+            for i in range(self.config.nr_tower):
                 with tf.device('/gpu:{}'.format(i)), \
                         tf.name_scope('tower{}'.format(i)) as scope:
                     model_inputs = get_model_inputs()
                     cost_var = model.get_cost(model_inputs, is_training=True)
                     grad_list.append(
-                config.optimizer.compute_gradients(cost_var))
+                        self.config.optimizer.compute_gradients(cost_var))
                     if i == 0:
                         tf.get_variable_scope().reuse_variables()
@@ -138,60 +231,39 @@ def start_train(config):
             for k in coll_keys:
                 del tf.get_collection(k)[:]
                 tf.get_collection(k).extend(kept_summaries[k])
-        grads = average_grads(grad_list)
+            grads = QueueInputTrainer._average_grads(grad_list)
         else:
             model_inputs = get_model_inputs()
             cost_var = model.get_cost(model_inputs, is_training=True)
-        grads = config.optimizer.compute_gradients(cost_var)
+            grads = self.config.optimizer.compute_gradients(cost_var)
         avg_maintain_op = summary_moving_average(cost_var)  # TODO(multigpu) average the cost from each device?
 
         check_grads(grads)
         grads = scale_grads(grads, model.get_lr_multiplier())
         summary_grads(grads)
 
-    train_op = tf.group(
-        config.optimizer.apply_gradients(grads, get_global_step_var()),
+        self.train_op = tf.group(
+            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             avg_maintain_op)
         describe_model()
 
-    sess = tf.Session(config=config.session_config)
-    config.session_init.init(sess)
-
-    # start training:
-    coord = tf.train.Coordinator()
-    tf.train.start_queue_runners(
-        sess=sess, coord=coord, daemon=True, start=True)
+        self.init_session_and_coord()
+
         # create a thread that keeps filling the queue
-    input_th = EnqueueThread(sess, coord, enqueue_op, config.dataset, input_queue)
+        input_th = EnqueueThread(self.sess, self.coord, enqueue_op, self.config.dataset, input_queue)
         input_th.start()
+        self.main_loop()
 
-    with sess.as_default():
-        try:
-            logger.info("Start training with global_step={}".format(get_global_step()))
-            callbacks.before_train()
-            tf.get_default_graph().finalize()
-            for epoch in xrange(1, config.max_epoch):
-                with timed_operation(
-                        'Epoch {}, global_step={}'.format(
-                            epoch, get_global_step() + config.step_per_epoch)):
-                    for step in tqdm.trange(
-                            config.step_per_epoch,
-                            leave=True, mininterval=0.5,
-                            dynamic_ncols=True, ascii=True):
-                        if coord.should_stop():
-                            return
-                        sess.run([train_op]) # faster since train_op return None
-                        callbacks.trigger_step()
-                    # note that summary_op will take a data from the queue
-                    callbacks.trigger_epoch()
-        except (KeyboardInterrupt, Exception):
-            raise
-        finally:
-            coord.request_stop()
-            # Do I need to run queue.close?
-            callbacks.after_train()
-            sess.close()
+
+    def run_step(self):
+        self.sess.run([self.train_op])  # faster since train_op returns None
+
+def start_train(config):
+    #if config.model.get_input_queue() is not None:
+        ## XXX get_input_queue is called twice
+        #tr = QueueInputTrainer()
+    #else:
+        #tr = SimpleTrainer()
+    #tr = SimpleTrainer(config)
+    tr = QueueInputTrainer(config)
+    tr.train()
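For reference, a hedged sketch of how the new entry point would be driven. The `TrainConfig` keyword names are assumed to match the attributes referenced above (`dataset`, `optimizer`, `callbacks`, `model`, `step_per_epoch`, `max_epoch`, `nr_tower`); `MyModel`, `my_dataflow`, and `my_callbacks` are placeholders, not part of this commit:

```python
config = TrainConfig(
    dataset=my_dataflow,                    # a DataFlow yielding training batches
    optimizer=tf.train.GradientDescentOptimizer(0.01),
    callbacks=my_callbacks,                 # provides before_train/trigger_step/trigger_epoch/after_train
    model=MyModel(),                        # a ModelDesc implementation
    step_per_epoch=1000,
    max_epoch=100,
    nr_tower=2,                             # >1 selects the multi-tower gradient-averaging path
)
start_train(config)                         # currently always constructs a QueueInputTrainer
```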