Commit efc74f2d authored by Yuxin Wu

refactor trainer

parent d3167ba3
...@@ -30,8 +30,11 @@ class Trainer(object):
     Available Attributes:
         stat_holder: a `StatHolder` instance
         summary_writer: a `tf.SummaryWriter`
+        summary_op: a `tf.Operation` which returns summary string
         config: a `TrainConfig`
         model: a `ModelDesc`
+        sess: a `tf.Session`
+        coord: a `tf.train.Coordinator`
     """
     __metaclass__ = ABCMeta
...
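These attributes are the hooks that the concrete trainers below rely on. A minimal hypothetical subclass, only to illustrate how `sess` and `coord` are meant to be used (not part of this commit):

    class VerboseTrainer(Trainer):
        def run_step(self):
            # sess and coord are provided by the Trainer base class;
            # train_op is assumed to have been built in _setup()
            if not self.coord.should_stop():
                self.sess.run(self.train_op)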
...@@ -15,16 +15,42 @@ from ..tfutils import (backup_collection, restore_collection,
                        get_global_step_var, TowerContext)
 from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
-from .trainer import QueueInputTrainer
+from .trainer import FeedlessTrainer
+from .queue import QueueInputTrainer

 __all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']

-class MultiGPUTrainer(QueueInputTrainer):
+class MultiGPUTrainer(FeedlessTrainer):
     """ Base class for multi-gpu training"""
+    def _multi_tower_grads(self):
+        logger.info("Training a model of {} towers".format(len(self.config.tower)))
+
+        grad_list = []
+        global_scope = tf.get_variable_scope()
+        for idx, t in enumerate(self.config.tower):
+            with tf.device('/gpu:{}'.format(t)), \
+                    tf.variable_scope(global_scope, reuse=idx > 0), \
+                    TowerContext('tower{}'.format(idx)) as scope:
+                logger.info("Building graph for training tower {}...".format(idx))
+                model_inputs = self._get_input_tensors_noreuse()
+                self.model.build_graph(model_inputs)
+                cost_var = self.model.get_cost()    # build tower
+
+                # TODO gate_gradients=0 might be faster?
+                grad_list.append(
+                    self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))
+
+                if idx == 0:
+                    add_moving_summary(cost_var)
+                    # avoid repeated summary from each device
+                    backup = backup_collection(SUMMARY_BACKUP_KEYS)
+        restore_collection(backup)
+        return grad_list
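The method above is the heart of the refactor: each tower builds the same graph on its own GPU, reusing variables after the first tower. A stripped-down sketch of that pattern (TF 1.x-era API; `towers`, `build_tower` and `optimizer` are stand-ins for the trainer's config and model):

    import tensorflow as tf

    def multi_tower_grads(towers, build_tower, optimizer):
        grad_list = []
        global_scope = tf.get_variable_scope()
        for idx, gpu_id in enumerate(towers):
            # reuse=True from the second tower on, so every tower shares
            # one set of variables while getting its own copy of the ops
            with tf.device('/gpu:{}'.format(gpu_id)), \
                    tf.variable_scope(global_scope, reuse=idx > 0):
                cost = build_tower()    # builds this tower's subgraph
                grad_list.append(optimizer.compute_gradients(cost))
        return grad_list                # one (grad, var) list per tower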
+class SyncMultiGPUTrainer(QueueInputTrainer, MultiGPUTrainer):
     def __init__(self, config, input_queue=None, predict_tower=None):
         super(MultiGPUTrainer, self).__init__(config, input_queue, predict_tower)
         assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
-        self.dequed_inputs = []

     @staticmethod
     def _average_grads(tower_grads):
...@@ -48,53 +74,18 @@ class MultiGPUTrainer(QueueInputTrainer):
             ret.append((grad, v))
         return ret
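The body of `_average_grads` is elided by this hunk; as a hedged reconstruction only, a typical implementation transposes the per-tower lists and averages gradients variable by variable:

    import tensorflow as tf

    def average_grads(tower_grads):
        # tower_grads: one list of (grad, var) pairs per tower
        ret = []
        for grad_and_vars in zip(*tower_grads):    # group by variable
            v = grad_and_vars[0][1]                # same variable in every tower
            # averaging per-tower grads makes N towers act as one large batch
            grad = tf.add_n([g for g, _ in grad_and_vars]) / float(len(tower_grads))
            ret.append((grad, v))
        return ret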
-    def _multi_tower_grads(self):
-        logger.info("Training a model of {} towers".format(len(self.config.tower)))
-
-        grad_list = []
-        global_scope = tf.get_variable_scope()
-        for idx, t in enumerate(self.config.tower):
-            with tf.device('/gpu:{}'.format(t)), \
-                    tf.variable_scope(global_scope, reuse=idx > 0), \
-                    TowerContext('tower{}'.format(idx)) as scope:
-                logger.info("Building graph for training tower {}...".format(idx))
-                model_inputs = self._get_dequeued_inputs()    # each tower dequeues from the input queue
-                self.dequed_inputs.append(model_inputs)
-                self.model.build_graph(model_inputs)
-                cost_var = self.model.get_cost()    # build tower
-
-                # TODO gate_gradients=0 might be faster?
-                grad_list.append(
-                    self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))
-
-                if idx == 0:
-                    add_moving_summary(cost_var)
-                    # avoid repeated summary from each device
-                    backup = backup_collection(SUMMARY_BACKUP_KEYS)
-        restore_collection(backup)
-        return grad_list
-
-class SyncMultiGPUTrainer(MultiGPUTrainer):
     def _setup(self):
-        self._build_enque_thread()
         grad_list = self._multi_tower_grads()
-        grads = MultiGPUTrainer._average_grads(grad_list)
+        grads = SyncMultiGPUTrainer._average_grads(grad_list)
         grads = apply_grad_processors(grads,
                                       self.model.get_gradient_processor())

         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average(), name='train_op')
-        # [debug]: do nothing in training
-        #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]

-class AsyncMultiGPUTrainer(MultiGPUTrainer):
+class AsyncMultiGPUTrainer(QueueInputTrainer, MultiGPUTrainer):
     def _setup(self):
-        self._build_enque_thread()
         grad_list = self._multi_tower_grads()
         gradprocs = self.model.get_gradient_processor()
         # pretend to average the grads, in order to make async and
...
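For context, a hypothetical call site for the classes above (the `TrainConfig` constructor arguments are not shown in this diff, so the details here are assumptions):

    config = TrainConfig(...)            # dataset, optimizer, callbacks, model, ...
    config.tower = [0, 1]                # train replicated towers on GPUs 0 and 1
    SyncMultiGPUTrainer(config).train()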
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: queue.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import threading
import tensorflow as tf
from ..dataflow.common import RepeatedData
from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..tfutils import get_global_step_var, TowerContext
from ..utils import logger
from ..callbacks.concurrency import StartProcOrThread
from ..tfutils.gradproc import apply_grad_processors
from .trainer import FeedlessTrainer, MultiPredictorTowerTrainer
__all__ = ['QueueInputTrainerBase', 'QueueInputTrainer']
class EnqueueThread(threading.Thread):
    def __init__(self, trainer):
        super(EnqueueThread, self).__init__()
        self.name = 'EnqueueThread'
        self.daemon = True

        self.sess = trainer.sess
        self.coord = trainer.coord
        self.dataflow = RepeatedData(trainer.config.dataset, -1)

        self.input_vars = trainer.input_vars
        self.queue = trainer.input_queue
        self.op = self.queue.enqueue(self.input_vars)
        self.close_op = self.queue.close(cancel_pending_enqueues=True)
        self.size_op = self.queue.size()
        add_moving_summary(tf.cast(
            self.size_op, tf.float32, name='input_queue_size'))

    def run(self):
        self.dataflow.reset_state()
        with self.sess.as_default():
            try:
                while True:
                    for dp in self.dataflow.get_data():
                        if self.coord.should_stop():
                            return
                        feed = dict(zip(self.input_vars, dp))
                        #print 'TFQ:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
                        self.op.run(feed_dict=feed)
            except tf.errors.CancelledError:
                pass
            except Exception:
                logger.exception("Exception in EnqueueThread:")
            finally:
                try:
                    self.sess.run(self.close_op)
                except RuntimeError:    # session already closed
                    pass
                self.coord.request_stop()
                logger.info("Enqueue Thread Exited.")
class QueueInputTrainerBase(FeedlessTrainer):
    def _build_enque_thread(self, input_queue):
        """ create a thread that keeps filling the queue """
        self.input_vars = self.model.get_input_vars()
        if input_queue is None:
            self.input_queue = tf.FIFOQueue(
                50, [x.dtype for x in self.input_vars], name='input_queue')
        else:
            self.input_queue = input_queue
        input_th = EnqueueThread(self)
        self.config.callbacks.append(StartProcOrThread(input_th))

    def _get_input_tensors_noreuse(self):
        """ Dequeue a datapoint from input_queue and return.
            Can be called multiple times.
        """
        ret = self.input_queue.dequeue(name='input_deque')
        if isinstance(ret, tf.Tensor):  # only one input
            ret = [ret]
        assert len(ret) == len(self.input_vars)
        for qv, v in zip(ret, self.input_vars):
            qv.set_shape(v.get_shape())

        # test the overhead of queue
        #with tf.device('/gpu:0'):
            #ret = [tf.Variable(tf.random_normal([128,224,224,3],
                #dtype=tf.float32), trainable=False),
                #tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
        return ret
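The `set_shape` loop exists because a `tf.FIFOQueue` declared with only dtypes remembers nothing about shapes, so dequeued tensors come back with unknown static shape. A short demonstration of the behavior being corrected:

    import tensorflow as tf

    q = tf.FIFOQueue(50, [tf.float32])
    img = tf.placeholder(tf.float32, shape=[None, 224, 224, 3])
    deq = q.dequeue()
    print(deq.get_shape())          # <unknown>
    deq.set_shape(img.get_shape())  # restore the static shape for graph building
    print(deq.get_shape())          # (?, 224, 224, 3)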
class QueueInputTrainer(MultiPredictorTowerTrainer, QueueInputTrainerBase):
    """ Single-GPU trainer which takes input from a queue """

    def __init__(self, config, input_queue=None, predict_tower=None):
        """
        :param config: a `TrainConfig` instance
        :param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
            Defaults to a FIFO queue of size 50.
        :param predict_tower: list of gpu relative idx to run prediction. default to be [0].
            Use -1 for cpu.
        """
        super(QueueInputTrainer, self).__init__(config)
        self._setup_predictor_factory(predict_tower)
        self._build_enque_thread(input_queue)
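A hypothetical construction, assuming the usual `TrainConfig` fields from this codebase (exact constructor arguments are not part of this diff):

    config = TrainConfig(...)    # dataset, optimizer, callbacks, model
    trainer = QueueInputTrainer(config, predict_tower=[0])
    trainer.train()

Passing `input_queue` is only needed to change the capacity or to substitute, e.g., a `tf.RandomShuffleQueue`.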
    def _single_tower_grad(self, actual_inputs):
        """ Get grad and cost for single-tower """
        with TowerContext(''):
            self.model.build_graph(actual_inputs)
            cost_var = self.model.get_cost()
        grads = self.config.optimizer.compute_gradients(
            cost_var, gate_gradients=0)  # GATE_NONE
        add_moving_summary(cost_var)
        return grads

    def _setup(self):
        assert len(self.config.tower) == 1, \
            "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
        actual_inputs = self._get_input_tensors_noreuse()
        grads = self._single_tower_grad(actual_inputs)
        grads = apply_grad_processors(grads,
                                      self.model.get_gradient_processor())

        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
            summary_moving_average(), name='train_op')
        # skip training
        #self.train_op = tf.group(*self.dequed_inputs)

    def run_step(self):
        """ Simply run self.train_op """
        self.sess.run(self.train_op)

        # debug-benchmark code:
        #run_metadata = tf.RunMetadata()
        #self.sess.run([self.train_op],
                #options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                #run_metadata=run_metadata
        #)
        #from tensorflow.python.client import timeline
        #trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        #trace_file = open('timeline.ctf.json', 'w')
        #trace_file.write(trace.generate_chrome_trace_format())
        #import sys; sys.exit()
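The commented-out block above traces a single training step. Cleaned up, the same profiling looks like this (assumes `sess` and `train_op` already exist; the resulting JSON opens in chrome://tracing):

    import tensorflow as tf
    from tensorflow.python.client import timeline

    run_metadata = tf.RunMetadata()
    # run one step with full tracing enabled
    sess.run([train_op],
             options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
             run_metadata=run_metadata)
    trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    with open('timeline.ctf.json', 'w') as f:
        f.write(trace.generate_chrome_trace_format())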
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
import threading
import time import time
from six.moves import zip from six.moves import zip
...@@ -16,10 +15,9 @@ from ..tfutils import (get_vars_by_names, freeze_collection, ...@@ -16,10 +15,9 @@ from ..tfutils import (get_vars_by_names, freeze_collection,
get_global_step_var, TowerContext) get_global_step_var, TowerContext)
from ..tfutils.summary import summary_moving_average, add_moving_summary from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..predict import OnlinePredictor, build_multi_tower_prediction_graph from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
from ..callbacks.concurrency import StartProcOrThread
from ..tfutils.gradproc import apply_grad_processors from ..tfutils.gradproc import apply_grad_processors
__all__ = ['SimpleTrainer', 'QueueInputTrainer'] __all__ = ['SimpleTrainer', 'FeedlessTrainer', 'MultiPredictorTowerTrainer']
class PredictorFactory(object): class PredictorFactory(object):
""" Make predictors for a trainer""" """ Make predictors for a trainer"""
...@@ -55,6 +53,7 @@ class PredictorFactory(object): ...@@ -55,6 +53,7 @@ class PredictorFactory(object):
self.tower_built = True self.tower_built = True
class SimpleTrainer(Trainer): class SimpleTrainer(Trainer):
""" A naive demo trainer """
def __init__(self, config): def __init__(self, config):
super(SimpleTrainer, self).__init__(config) super(SimpleTrainer, self).__init__(config)
self._predictor_factory = PredictorFactory(self.sess, self.model, [0]) self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
...@@ -94,134 +93,26 @@ class SimpleTrainer(Trainer): ...@@ -94,134 +93,26 @@ class SimpleTrainer(Trainer):
def get_predict_func(self, input_names, output_names): def get_predict_func(self, input_names, output_names):
return self._predictor_factory.get_predictor(input_names, output_names, 0) return self._predictor_factory.get_predictor(input_names, output_names, 0)
-class EnqueueThread(threading.Thread):
-    def __init__(self, trainer):
-        super(EnqueueThread, self).__init__()
-        self.name = 'EnqueueThread'
-        self.daemon = True
-
-        self.sess = trainer.sess
-        self.coord = trainer.coord
-        self.dataflow = RepeatedData(trainer.config.dataset, -1)
-
-        self.input_vars = trainer.input_vars
-        self.queue = trainer.input_queue
-        self.op = self.queue.enqueue(self.input_vars)
-        self.close_op = self.queue.close(cancel_pending_enqueues=True)
-        self.size_op = self.queue.size()
-        add_moving_summary(tf.cast(
-            self.size_op, tf.float32, name='input_queue_size'))
-
-    def run(self):
-        self.dataflow.reset_state()
-        with self.sess.as_default():
-            try:
-                while True:
-                    for dp in self.dataflow.get_data():
-                        if self.coord.should_stop():
-                            return
-                        feed = dict(zip(self.input_vars, dp))
-                        #print 'TFQ:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
-                        self.op.run(feed_dict=feed)
-            except tf.errors.CancelledError:
-                pass
-            except Exception:
-                logger.exception("Exception in EnqueueThread:")
-            finally:
-                try:
-                    self.sess.run(self.close_op)
-                except RuntimeError:    # session already closed
-                    pass
-                self.coord.request_stop()
-                logger.info("Enqueue Thread Exited.")
-
-class QueueInputTrainer(Trainer):
-    """ Single-GPU trainer which takes input from a queue """
-    def __init__(self, config, input_queue=None, predict_tower=None):
-        """
-        :param config: a `TrainConfig` instance
-        :param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
-            Defaults to a FIFO queue of size 50.
-        :param predict_tower: list of gpu relative idx to run prediction. default to be [0].
-            Use -1 for cpu.
-        """
-        super(QueueInputTrainer, self).__init__(config)
-        self.input_vars = self.model.get_input_vars()
-        if input_queue is None:
-            self.input_queue = tf.FIFOQueue(
-                50, [x.dtype for x in self.input_vars], name='input_queue')
-        else:
-            self.input_queue = input_queue
-
-        # by default, use the first training gpu for prediction
-        predict_tower = predict_tower or [0]
-        self._predictor_factory = PredictorFactory(
-            self.sess, self.model, predict_tower)
-        self.dequed_inputs = None
-
-    def _get_dequeued_inputs(self):
-        """ Dequeue a datapoint from input_queue and return """
-        ret = self.input_queue.dequeue(name='input_deque')
-        if isinstance(ret, tf.Tensor):  # only one input
-            ret = [ret]
-        assert len(ret) == len(self.input_vars)
-        for qv, v in zip(ret, self.input_vars):
-            qv.set_shape(v.get_shape())
-        return ret
-
-    def _single_tower_grad(self):
-        """ Get grad and cost for single-tower """
-        self.dequed_inputs = model_inputs = self._get_dequeued_inputs()
-
-        # test the overhead of queue
-        #with tf.device('/gpu:0'):
            -#self.dequed_inputs = [tf.Variable(tf.random_normal([128,224,224,3],
                -#dtype=tf.float32), trainable=False),
                -#tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
-
-        with TowerContext(''):
-            self.model.build_graph(self.dequed_inputs)
-            cost_var = self.model.get_cost()
-        grads = self.config.optimizer.compute_gradients(
-            cost_var, gate_gradients=0)  # GATE_NONE
-        add_moving_summary(cost_var)
-        return grads
-
-    def _build_enque_thread(self):
-        """ create a thread that keeps filling the queue """
-        self.input_th = EnqueueThread(self)
-        self.config.callbacks.append(StartProcOrThread(self.input_th))
-
-    def _setup(self):
-        assert len(self.config.tower) == 1, \
-            "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
-        self._build_enque_thread()
-
-        grads = self._single_tower_grad()
-        grads = apply_grad_processors(grads,
-                                      self.model.get_gradient_processor())
-
-        self.train_op = tf.group(
-            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
-            summary_moving_average(), name='train_op')
-        # skip training
-        #self.train_op = tf.group(*self.dequed_inputs)
-
-    def run_step(self):
-        """ Simply run self.train_op """
-        self.sess.run(self.train_op)
-        #run_metadata = tf.RunMetadata()
-        #self.sess.run([self.train_op],
                -#options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                -#run_metadata=run_metadata
        -#)
-        #from tensorflow.python.client import timeline
-        #trace = timeline.Timeline(step_stats=run_metadata.step_stats)
-        #trace_file = open('timeline.ctf.json', 'w')
-        #trace_file.write(trace.generate_chrome_trace_format())
-        #import sys; sys.exit()
+class MultiPredictorTowerTrainer(Trainer):
+    """ A trainer with possibly multiple prediction towers """
+
+    def _setup_predictor_factory(self, predict_tower):
+        # by default, use the first training gpu for prediction
+        predict_tower = predict_tower or [0]
+        self._predictor_factory = PredictorFactory(
+            self.sess, self.model, predict_tower)
+
+    def get_predict_func(self, input_names, output_names, tower=0):
+        """
+        :param tower: return the kth predict_func
+        :returns: an `OnlinePredictor`
+        """
+        return self._predictor_factory.get_predictor(input_names, output_names, tower)
+
+    def get_predict_funcs(self, input_names, output_names, n):
+        return [self.get_predict_func(input_names, output_names, k) for k in range(n)]
+
+class FeedlessTrainer(Trainer):
+    """ A trainer which runs iterations without feed_dict (therefore faster) """

     def _trigger_epoch(self):
         # need to run summary_op every epoch
         # note that summary_op will take a datapoint from the queue
...@@ -229,12 +120,7 @@ class QueueInputTrainer(Trainer):
         summary_str = self.summary_op.eval()
         self._process_summary(summary_str)

-    def get_predict_func(self, input_names, output_names, tower=0):
-        """
-        :param tower: return the kth predict_func
-        :returns: an `OnlinePredictor`
-        """
-        return self._predictor_factory.get_predictor(input_names, output_names, tower)
-
-    def get_predict_funcs(self, input_names, output_names, n):
-        return [self.get_predict_func(input_names, output_names, k) for k in range(n)]
+    def _get_input_tensors_noreuse(self):
+        """ Return a list of actual input tensors.
+            Always return new tensors (for multi-tower) if called multiple times.
+        """
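Taken together, the refactor leaves this class layout (reconstructed from the hunks above; method bodies elided):

    Trainer
    ├── SimpleTrainer                    # trainer.py: naive feed_dict-based demo
    ├── MultiPredictorTowerTrainer       # trainer.py: get_predict_func(s)
    └── FeedlessTrainer                  # trainer.py: no feed_dict per step
        ├── QueueInputTrainerBase        # queue.py: inputs via queue dequeue
        └── MultiGPUTrainer              # multigpu.py: per-tower graph building

    QueueInputTrainer(MultiPredictorTowerTrainer, QueueInputTrainerBase)
    SyncMultiGPUTrainer(QueueInputTrainer, MultiGPUTrainer)
    AsyncMultiGPUTrainer(QueueInputTrainer, MultiGPUTrainer)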