Commit 9952c6c6 authored by Yuxin Wu

add trainer with shared mainloop

parent de37446f
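This commit replaces the monolithic `start_train` function (removed below) with a `Trainer` base class that owns the shared main loop; subclasses only supply graph construction (`train`) and the work of a single iteration (`run_step`). A minimal sketch of that contract, using a hypothetical subclass that is not part of the commit:

    # Hypothetical subclass illustrating the new contract (illustrative only).
    # `Trainer` provides main_loop() and init_session_and_coord(); a subclass
    # builds the graph in train() and defines one iteration in run_step().
    class MyTrainer(Trainer):
        def train(self):
            # ... build the model and create self.train_op here ...
            self.init_session_and_coord()   # session + queue-runner coordinator
            self.main_loop()                # shared epoch/step/callback loop

        def run_step(self):
            self.sess.run([self.train_op])  # one gradient update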
@@ -9,6 +9,7 @@ import copy
 import argparse
 import re
 import tqdm
+from abc import ABCMeta, abstractmethod
 from .models import ModelDesc
 from .dataflow.common import RepeatedData
@@ -63,14 +64,6 @@ class TrainConfig(object):
         self.nr_tower = int(kwargs.pop('nr_tower', 1))
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
 
-def average_grads(tower_grads):
-    ret = []
-    for grad_and_vars in zip(*tower_grads):
-        grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
-        v = grad_and_vars[0][1]
-        ret.append((grad, v))
-    return ret
-
 def summary_grads(grads):
     for grad, var in grads:
         if grad:
@@ -95,103 +88,182 @@ def scale_grads(grads, multiplier):
         ret.append((grad, var))
     return ret
 
-def start_train(config):
-    """
-    Start training with a config
-    Args:
-        config: a TrainConfig instance
-    """
-    model = config.model
-    input_vars = model.get_input_vars()
-    input_queue = model.get_input_queue()
-    callbacks = config.callbacks
-    tf.add_to_collection(MODEL_KEY, model)
-    enqueue_op = input_queue.enqueue(input_vars)
-
-    def get_model_inputs():
-        model_inputs = input_queue.dequeue()
-        if isinstance(model_inputs, tf.Tensor):  # only one input
-            model_inputs = [model_inputs]
-        for qv, v in zip(model_inputs, input_vars):
-            qv.set_shape(v.get_shape())
-        return model_inputs
-
-    # get gradients to update:
-    if config.nr_tower > 1:
-        logger.info("Training a model of {} tower".format(config.nr_tower))
-        # to avoid repeated summary from each device
-        coll_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
-        kept_summaries = {}
-        grad_list = []
-        for i in range(config.nr_tower):
-            with tf.device('/gpu:{}'.format(i)), \
-                    tf.name_scope('tower{}'.format(i)) as scope:
-                model_inputs = get_model_inputs()
-                cost_var = model.get_cost(model_inputs, is_training=True)
-                grad_list.append(
-                    config.optimizer.compute_gradients(cost_var))
-                if i == 0:
-                    tf.get_variable_scope().reuse_variables()
-                    for k in coll_keys:
-                        kept_summaries[k] = copy.copy(tf.get_collection(k))
-        for k in coll_keys:
-            del tf.get_collection(k)[:]
-            tf.get_collection(k).extend(kept_summaries[k])
-        grads = average_grads(grad_list)
-    else:
-        model_inputs = get_model_inputs()
-        cost_var = model.get_cost(model_inputs, is_training=True)
-        grads = config.optimizer.compute_gradients(cost_var)
-    avg_maintain_op = summary_moving_average(cost_var)  # TODO(multigpu) average the cost from each device?
-
-    check_grads(grads)
-    grads = scale_grads(grads, model.get_lr_multiplier())
-    summary_grads(grads)
-
-    train_op = tf.group(
-        config.optimizer.apply_gradients(grads, get_global_step_var()),
-        avg_maintain_op)
-    describe_model()
-
-    sess = tf.Session(config=config.session_config)
-    config.session_init.init(sess)
-
-    # start training:
-    coord = tf.train.Coordinator()
-    tf.train.start_queue_runners(
-        sess=sess, coord=coord, daemon=True, start=True)
-
-    # create a thread that keeps filling the queue
-    input_th = EnqueueThread(sess, coord, enqueue_op, config.dataset, input_queue)
-    input_th.start()
-
-    with sess.as_default():
-        try:
-            logger.info("Start training with global_step={}".format(get_global_step()))
-            callbacks.before_train()
-            tf.get_default_graph().finalize()
-            for epoch in xrange(1, config.max_epoch):
-                with timed_operation(
-                        'Epoch {}, global_step={}'.format(
-                            epoch, get_global_step() + config.step_per_epoch)):
-                    for step in tqdm.trange(
-                            config.step_per_epoch,
-                            leave=True, mininterval=0.5,
-                            dynamic_ncols=True, ascii=True):
-                        if coord.should_stop():
-                            return
-                        sess.run([train_op])  # faster since train_op return None
-                        callbacks.trigger_step()
-                    # note that summary_op will take a data from the queue
-                    callbacks.trigger_epoch()
-        except (KeyboardInterrupt, Exception):
-            raise
-        finally:
-            coord.request_stop()
-            # Do I need to run queue.close?
-            callbacks.after_train()
-            sess.close()
""" """
Start training with a config Trainer which builds a queue for input.
Args: Support multi GPU.
config: a TrainConfig instance
""" """
model = config.model
input_vars = model.get_input_vars()
input_queue = model.get_input_queue()
callbacks = config.callbacks
tf.add_to_collection(MODEL_KEY, model)
enqueue_op = input_queue.enqueue(input_vars)
def get_model_inputs():
model_inputs = input_queue.dequeue()
if isinstance(model_inputs, tf.Tensor): # only one input
model_inputs = [model_inputs]
for qv, v in zip(model_inputs, input_vars):
qv.set_shape(v.get_shape())
return model_inputs
# get gradients to update:
if config.nr_tower > 1:
logger.info("Training a model of {} tower".format(config.nr_tower))
# to avoid repeated summary from each device
coll_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
kept_summaries = {}
grad_list = []
for i in range(config.nr_tower):
with tf.device('/gpu:{}'.format(i)), \
tf.name_scope('tower{}'.format(i)) as scope:
model_inputs = get_model_inputs()
cost_var = model.get_cost(model_inputs, is_training=True)
grad_list.append(
config.optimizer.compute_gradients(cost_var))
if i == 0:
tf.get_variable_scope().reuse_variables()
for k in coll_keys:
kept_summaries[k] = copy.copy(tf.get_collection(k))
for k in coll_keys:
del tf.get_collection(k)[:]
tf.get_collection(k).extend(kept_summaries[k])
grads = average_grads(grad_list)
else:
model_inputs = get_model_inputs()
cost_var = model.get_cost(model_inputs, is_training=True)
grads = config.optimizer.compute_gradients(cost_var)
avg_maintain_op = summary_moving_average(cost_var) # TODO(multigpu) average the cost from each device?
check_grads(grads)
grads = scale_grads(grads, model.get_lr_multiplier())
summary_grads(grads)
train_op = tf.group(
config.optimizer.apply_gradients(grads, get_global_step_var()),
avg_maintain_op)
describe_model()
sess = tf.Session(config=config.session_config)
config.session_init.init(sess)
# start training:
coord = tf.train.Coordinator()
tf.train.start_queue_runners(
sess=sess, coord=coord, daemon=True, start=True)
# create a thread that keeps filling the queue
input_th = EnqueueThread(sess, coord, enqueue_op, config.dataset, input_queue)
input_th.start()
with sess.as_default():
try:
logger.info("Start training with global_step={}".format(get_global_step()))
callbacks.before_train()
tf.get_default_graph().finalize()
for epoch in xrange(1, config.max_epoch):
with timed_operation(
'Epoch {}, global_step={}'.format(
epoch, get_global_step() + config.step_per_epoch)):
for step in tqdm.trange(
config.step_per_epoch,
leave=True, mininterval=0.5,
dynamic_ncols=True, ascii=True):
if coord.should_stop():
return
sess.run([train_op]) # faster since train_op return None
callbacks.trigger_step()
# note that summary_op will take a data from the queue
callbacks.trigger_epoch()
except (KeyboardInterrupt, Exception):
raise
finally:
coord.request_stop()
# Do I need to run queue.close?
callbacks.after_train()
sess.close()
+    @staticmethod
+    def _average_grads(tower_grads):
+        ret = []
+        for grad_and_vars in zip(*tower_grads):
+            grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
+            v = grad_and_vars[0][1]
+            ret.append((grad, v))
+        return ret
+
+    def train(self):
+        model = self.config.model
+        input_vars = model.get_input_vars()
+        input_queue = model.get_input_queue()
+        enqueue_op = input_queue.enqueue(input_vars)
+
+        def get_model_inputs():
+            model_inputs = input_queue.dequeue()
+            if isinstance(model_inputs, tf.Tensor):  # only one input
+                model_inputs = [model_inputs]
+            for qv, v in zip(model_inputs, input_vars):
+                qv.set_shape(v.get_shape())
+            return model_inputs
+
+        # get gradients to update:
+        if self.config.nr_tower > 1:
+            logger.info("Training a model of {} towers".format(self.config.nr_tower))
+            # to avoid repeated summary from each device
+            coll_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
+            kept_summaries = {}
+            grad_list = []
+            for i in range(self.config.nr_tower):
+                with tf.device('/gpu:{}'.format(i)), \
+                        tf.name_scope('tower{}'.format(i)) as scope:
+                    model_inputs = get_model_inputs()
+                    cost_var = model.get_cost(model_inputs, is_training=True)
+                    grad_list.append(
+                        self.config.optimizer.compute_gradients(cost_var))
+                    if i == 0:
+                        tf.get_variable_scope().reuse_variables()
+                        for k in coll_keys:
+                            kept_summaries[k] = copy.copy(tf.get_collection(k))
+            for k in coll_keys:
+                del tf.get_collection(k)[:]
+                tf.get_collection(k).extend(kept_summaries[k])
+            grads = QueueInputTrainer._average_grads(grad_list)
+        else:
+            model_inputs = get_model_inputs()
+            cost_var = model.get_cost(model_inputs, is_training=True)
+            grads = self.config.optimizer.compute_gradients(cost_var)
+        avg_maintain_op = summary_moving_average(cost_var)  # TODO(multigpu) average the cost from each device?
+
+        check_grads(grads)
+        grads = scale_grads(grads, model.get_lr_multiplier())
+        summary_grads(grads)
+
+        self.train_op = tf.group(
+            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
+            avg_maintain_op)
+        describe_model()
+
+        self.init_session_and_coord()
+        # create a thread that keeps filling the queue
+        input_th = EnqueueThread(self.sess, self.coord, enqueue_op, self.config.dataset, input_queue)
+        input_th.start()
+        self.main_loop()
+
+    def run_step(self):
+        self.sess.run([self.train_op])  # faster since train_op returns None
+
+def start_train(config):
+    #if config.model.get_input_queue() is not None:
+    #    ## XXX get_input_queue is called twice
+    #    tr = QueueInputTrainer()
+    #else:
+    #    tr = SimpleTrainer()
+    #tr = SimpleTrainer(config)
+    tr = QueueInputTrainer(config)
+    tr.train()
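For reference on the multi-tower path: `_average_grads` expects one `[(grad, var), ...]` list per tower, with variables aligned across the lists, and takes the arithmetic mean of each variable's gradients. A pure-Python sketch of the same bookkeeping (floats stand in for gradient tensors; illustrative only, not tensorpack API):

    def average_grads(tower_grads):
        # tower_grads: one [(grad, var), ...] list per tower, variables aligned
        ret = []
        for grad_and_vars in zip(*tower_grads):      # group the same var across towers
            grad = sum(g for g, _ in grad_and_vars) / float(len(tower_grads))
            ret.append((grad, grad_and_vars[0][1]))  # keep the var from tower 0
        return ret

    # two towers, two variables 'w' and 'b'
    tower0 = [(4.0, 'w'), (2.0, 'b')]
    tower1 = [(2.0, 'w'), (0.0, 'b')]
    print(average_grads([tower0, tower1]))           # [(3.0, 'w'), (1.0, 'b')]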