Commit 1db8e2b4 authored by Yuxin Wu

more on trainers

parent efc74f2d
@@ -12,7 +12,7 @@ import six
 from ..utils import logger
 from .common import get_op_var_name
-from .varmanip import SessionUpdate, get_savename_from_varname, is_training_specific_name
+from .varmanip import SessionUpdate, get_savename_from_varname, is_training_name
 __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
            'ParamRestore', 'ChainInit',
@@ -129,7 +129,7 @@ class SaverRestore(SessionInit):
                 var_dict[name].append(v)
                 chkpt_vars_used.add(name)
             else:
-                if not is_training_specific_name(v.op.name):
+                if not is_training_name(v.op.name):
                     logger.warn("Variable {} in the graph not found in checkpoint!".format(v.op.name))
         if len(chkpt_vars_used) < len(vars_available):
             unused = vars_available - chkpt_vars_used
...
@@ -101,19 +101,21 @@ def add_moving_summary(v, *args):
         assert x.get_shape().ndims == 0
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x)

-def summary_moving_average():
-    """ Create a MovingAverage op and summary for all variables in MOVING_SUMMARY_VARS_KEY.
-    :returns: a op to maintain these average.
-    """
+def summary_moving_average(tensors=None):
+    """
+    Create a MovingAverage op and summary for tensors
+    :param tensors: list of tf.Tensor to summary. default to the collection MOVING_SUMMARY_VARS_KEY
+    :returns: a op to maintain these average.
+    """
+    if tensors is None:
+        tensors = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
     with tf.name_scope('EMA_summary'):
         # TODO will produce EMA_summary/tower0/xxx. not elegant
-        global_step_var = get_global_step_var()
         with tf.name_scope(None):
             averager = tf.train.ExponentialMovingAverage(
-                0.99, num_updates=global_step_var, name='EMA')
-        vars_to_summary = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
-        avg_maintain_op = averager.apply(vars_to_summary)
-        for idx, c in enumerate(vars_to_summary):
+                0.99, num_updates=get_global_step_var(), name='EMA')
+        avg_maintain_op = averager.apply(tensors)
+        for idx, c in enumerate(tensors):
             name = re.sub('tower[p0-9]+/', '', c.op.name)
             tf.scalar_summary(name, averager.average(c))
         return avg_maintain_op
...
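A minimal usage sketch for the new `tensors` argument of `summary_moving_average` (not part of the diff; `cost` and `accuracy` are placeholder scalars, and the import path follows the layout implied by the hunks below):

    # Illustrative call site only, using the TF 0.x API this file already relies on.
    import tensorflow as tf
    from tensorpack.tfutils.summary import summary_moving_average

    cost = tf.reduce_mean(tf.square(tf.random_normal([32])), name='cost')
    accuracy = tf.constant(0.5, name='accuracy')

    # New in this commit: pass an explicit tensor list. With tensors=None the op still
    # falls back to the MOVING_SUMMARY_VARS_KEY collection, so existing callers keep working.
    maintain_op = summary_moving_average([cost, accuracy])
    # maintain_op is meant to be run every step, e.g. grouped into train_op as the trainers do.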
@@ -13,7 +13,7 @@ from ..utils.naming import *
 from .common import get_op_tensor_name
 __all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
-           'get_savename_from_varname', 'is_training_specific_name']
+           'get_savename_from_varname', 'is_training_name']
 def get_savename_from_varname(
         varname, varname_prefix=None,
@@ -97,7 +97,7 @@ def dump_chkpt_vars(model_path):
         result[n] = reader.get_tensor(n)
     return result
-def is_training_specific_name(name):
+def is_training_name(name):
     """
     This is only used to improve logging.
     :returns: guess whether this tensor is something only used in training.
...
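For context, helpers like `is_training_name` typically just pattern-match the variable name to guess whether it is a training-only artifact (optimizer slots, EMA shadows), so that `SaverRestore` above can skip the missing-variable warning for it. A rough illustration of the idea follows; the substring list is an assumption for illustration, not the function's actual body:

    def is_training_name_sketch(name):
        """ Guess whether a variable only exists for training.
            Illustrative only -- the real is_training_name may use different patterns. """
        lowered = name.lower()
        return any(k in lowered for k in ('momentum', 'adam', 'global_step', '/ema'))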
@@ -139,7 +139,7 @@ class Trainer(object):
                 if self.coord.should_stop():
                     return
                 self.run_step() # implemented by subclass
-                #callbacks.trigger_step() # not useful?
+                callbacks.trigger_step() # not useful?
             self.trigger_epoch()
         except StopTraining:
             logger.info("Training was stopped.")
...
@@ -15,30 +15,26 @@ from ..tfutils import (backup_collection, restore_collection,
         get_global_step_var, TowerContext)
 from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
-from .trainer import FeedlessTrainer
-from .queue import QueueInputTrainer
+from .trainer import FeedlessTrainer, SingleCostFeedlessTrainer
+from .queue import QueueInputTrainer, QueueInputTrainerBase
 __all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
 class MultiGPUTrainer(FeedlessTrainer):
     """ Base class for multi-gpu training"""
-    def _multi_tower_grads(self):
-        logger.info("Training a model of {} tower".format(len(self.config.tower)))
+    @staticmethod
+    def _multi_tower_grads(towers, get_tower_grad_func):
+        logger.info("Training a model of {} tower".format(len(towers)))
         grad_list = []
         global_scope = tf.get_variable_scope()
-        for idx, t in enumerate(self.config.tower):
+        for idx, t in enumerate(towers):
             with tf.device('/gpu:{}'.format(t)), \
                     tf.variable_scope(global_scope, reuse=idx > 0), \
                     TowerContext('tower{}'.format(idx)) as scope:
                 logger.info("Building graph for training tower {}...".format(idx))
-                model_inputs = self._get_input_tensors_noreuse()
-                self.model.build_graph(model_inputs)
-                cost_var = self.model.get_cost()
-                # TODO gate_gradienst=0 might be faster?
-                grad_list.append(
-                    self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))
+                # build tower
+                grad_list.append(get_tower_grad_func())
                 if idx == 0:
                     add_moving_summary(cost_var)
@@ -47,10 +43,12 @@ class MultiGPUTrainer(FeedlessTrainer):
             restore_collection(backup)
         return grad_list
-class SyncMultiGPUTrainer(QueueInputTrainer, MultiGPUTrainer):
+class SyncMultiGPUTrainer(QueueInputTrainerBase, MultiGPUTrainer, SingleCostFeedlessTrainer):
     def __init__(self, config, input_queue=None, predict_tower=None):
-        super(MultiGPUTrainer, self).__init__(config, input_queue, predict_tower)
         assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
+        super(SyncMultiGPUTrainer, self).__init__(config)
+        self._setup_predictor_factory(predict_tower)
+        self._build_enque_thread(input_queue)
     @staticmethod
     def _average_grads(tower_grads):
@@ -75,18 +73,28 @@ class SyncMultiGPUTrainer(QueueInputTrainer, MultiGPUTrainer):
         return ret
     def _setup(self):
-        grad_list = self._multi_tower_grads()
+        grad_list = MultiGPUTrainer._multi_tower_grads(
+            self.config.tower, lambda: self._get_cost_and_grad()[1])
         grads = SyncMultiGPUTrainer._average_grads(grad_list)
-        grads = apply_grad_processors(grads,
-                self.model.get_gradient_processor())
+        grads = apply_grad_processors(grads, self.model.get_gradient_processor())
         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average(), name='train_op')
-class AsyncMultiGPUTrainer(QueueInputTrainer, MultiGPUTrainer):
+    def run_step(self):
+        self.sess.run(self.train_op)
+class AsyncMultiGPUTrainer(QueueInputTrainerBase, MultiGPUTrainer, SingleCostFeedlessTrainer):
+    def __init__(self, config, input_queue=None, predict_tower=None):
+        assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
+        super(AsyncMultiGPUTrainer, self).__init__(config)
+        self._setup_predictor_factory(predict_tower)
+        self._build_enque_thread(input_queue)
     def _setup(self):
-        grad_list = self._multi_tower_grads()
+        grad_list = MultiGPUTrainer._multi_tower_grads(
+            self.config.tower, lambda: self._get_cost_and_grad()[1])
         gradprocs = self.model.get_gradient_processor()
         # pretend to average the grads, in order to make async and
         # sync have consistent effective learning rate
...
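To make the new control flow concrete: the tower loop above now only handles device placement and variable-scope reuse, while per-tower graph construction is delegated to a closure supplied by the single-cost trainer (`lambda: self._get_cost_and_grad()[1]`). A self-contained toy with the same shape, using made-up model code instead of the real model/config machinery:

    import tensorflow as tf

    def multi_tower_grads(towers, get_tower_grad_func):
        # Same skeleton as MultiGPUTrainer._multi_tower_grads: placement and reuse here,
        # per-tower graph building inside the callback.
        grad_list = []
        scope = tf.get_variable_scope()
        for idx, t in enumerate(towers):
            with tf.device('/gpu:{}'.format(t)), tf.variable_scope(scope, reuse=idx > 0):
                grad_list.append(get_tower_grad_func())
        return grad_list

    def toy_cost_and_grad():
        # Stand-in for SingleCostFeedlessTrainer._get_cost_and_grad().
        w = tf.get_variable('w', shape=[1], initializer=tf.constant_initializer(0.0))
        cost = tf.reduce_sum(tf.square(w - 3.0))
        grads = tf.train.GradientDescentOptimizer(0.1).compute_gradients(cost)
        return cost, grads

    grad_list = multi_tower_grads([0, 1], lambda: toy_cost_and_grad()[1])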
@@ -13,7 +13,8 @@ from ..utils import logger
 from ..callbacks.concurrency import StartProcOrThread
 from ..tfutils.gradproc import apply_grad_processors
-from .trainer import FeedlessTrainer, MultiPredictorTowerTrainer
+from .trainer import (FeedlessTrainer, MultiPredictorTowerTrainer,
+        SingleCostFeedlessTrainer)
 __all__ = ['QueueInputTrainerBase', 'QueueInputTrainer']
@@ -88,7 +89,7 @@ class QueueInputTrainerBase(FeedlessTrainer):
                 #tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
         return ret
-class QueueInputTrainer(MultiPredictorTowerTrainer, QueueInputTrainerBase):
+class QueueInputTrainer(MultiPredictorTowerTrainer, QueueInputTrainerBase, SingleCostFeedlessTrainer):
     """ Single GPU Trainer, takes input from a queue"""
     def __init__(self, config, input_queue=None, predict_tower=None):
@@ -103,23 +104,12 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, QueueInputTrainerBase):
         self._setup_predictor_factory(predict_tower)
         self._build_enque_thread(input_queue)
-    def _single_tower_grad(self, actual_inputs):
-        """ Get grad and cost for single-tower"""
-        with TowerContext(''):
-            self.model.build_graph(actual_inputs)
-            cost_var = self.model.get_cost()
-        grads = self.config.optimizer.compute_gradients(
-                cost_var, gate_gradients=0) # GATE_NONE
-        add_moving_summary(cost_var)
-        return grads
     def _setup(self):
         assert len(self.config.tower) == 1, \
                 "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
-        actual_inputs = self._get_input_tensors_noreuse()
-        grads = self._single_tower_grad(actual_inputs)
-        grads = apply_grad_processors(grads,
-                self.model.get_gradient_processor())
+        with TowerContext(''):
+            cost, grads = self._get_cost_and_grad()
+        grads = apply_grad_processors(grads, self.model.get_gradient_processor())
         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
...
@@ -17,7 +17,8 @@ from ..tfutils.summary import summary_moving_average, add_moving_summary
 from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
 from ..tfutils.gradproc import apply_grad_processors
-__all__ = ['SimpleTrainer', 'FeedlessTrainer', 'MultiPredictorTowerTrainer']
+__all__ = ['SimpleTrainer', 'FeedlessTrainer', 'MultiPredictorTowerTrainer',
+        'SingleCostFeedlessTrainer']
 class PredictorFactory(object):
     """ Make predictors for a trainer"""
@@ -124,3 +125,16 @@ class FeedlessTrainer(Trainer):
         """ return a list of actual input tensors.
             Always return new tensors (for multi tower) if called mutliple times.
         """
+class SingleCostFeedlessTrainer(Trainer):
+    def _get_cost_and_grad(self):
+        """ get the cost and gradient on a new tower"""
+        actual_inputs = self._get_input_tensors_noreuse()
+        self.model.build_graph(actual_inputs)
+        cost_var = self.model.get_cost()
+        # GATE_NONE faster?
+        grads = self.config.optimizer.compute_gradients(
+                cost_var, gate_gradients=0)
+        add_moving_summary(cost_var)
+        return cost_var, grads
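The mixin introduced here is what the queue and multi-GPU trainers above now share: anything with a single scalar cost gets its `(cost, grads)` pair from `_get_cost_and_grad()` and only has to turn the gradients into a train op. A schematic subclass, not part of the commit (import paths assume the usual `tensorpack.train`/`tensorpack.tfutils` layout, and a concrete trainer would also need to provide `_get_input_tensors_noreuse`, e.g. by mixing in `QueueInputTrainerBase` as the classes above do):

    import tensorflow as tf
    from tensorpack.tfutils import get_global_step_var
    from tensorpack.tfutils.summary import summary_moving_average
    from tensorpack.tfutils.gradproc import apply_grad_processors
    from tensorpack.train.trainer import SingleCostFeedlessTrainer

    class MySingleCostTrainer(SingleCostFeedlessTrainer):
        """ Schematic: build one tower, apply gradient processors, group in the EMA op. """
        def _setup(self):
            cost, grads = self._get_cost_and_grad()
            grads = apply_grad_processors(grads, self.model.get_gradient_processor())
            self.train_op = tf.group(
                self.config.optimizer.apply_gradients(grads, get_global_step_var()),
                summary_moving_average(), name='train_op')

        def run_step(self):
            self.sess.run(self.train_op)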