Commit 1db8e2b4 authored by Yuxin Wu

more on trainers

parent efc74f2d
......@@ -12,7 +12,7 @@ import six
from ..utils import logger
from .common import get_op_var_name
from .varmanip import SessionUpdate, get_savename_from_varname, is_training_specific_name
from .varmanip import SessionUpdate, get_savename_from_varname, is_training_name
__all__ = ['SessionInit', 'NewSession', 'SaverRestore',
'ParamRestore', 'ChainInit',
......@@ -129,7 +129,7 @@ class SaverRestore(SessionInit):
var_dict[name].append(v)
chkpt_vars_used.add(name)
else:
if not is_training_specific_name(v.op.name):
if not is_training_name(v.op.name):
logger.warn("Variable {} in the graph not found in checkpoint!".format(v.op.name))
if len(chkpt_vars_used) < len(vars_available):
unused = vars_available - chkpt_vars_used
......
......@@ -101,19 +101,21 @@ def add_moving_summary(v, *args):
assert x.get_shape().ndims == 0
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x)
def summary_moving_average():
""" Create a MovingAverage op and summary for all variables in MOVING_SUMMARY_VARS_KEY.
def summary_moving_average(tensors=None):
"""
Create a MovingAverage op and summaries for the given tensors.
:param tensors: list of tf.Tensor to summarize. Defaults to the collection MOVING_SUMMARY_VARS_KEY.
:returns: an op to maintain these averages.
"""
if tensors is None:
tensors = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
with tf.name_scope('EMA_summary'):
# TODO will produce EMA_summary/tower0/xxx. not elegant
global_step_var = get_global_step_var()
with tf.name_scope(None):
averager = tf.train.ExponentialMovingAverage(
0.99, num_updates=global_step_var, name='EMA')
vars_to_summary = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
avg_maintain_op = averager.apply(vars_to_summary)
for idx, c in enumerate(vars_to_summary):
0.99, num_updates=get_global_step_var(), name='EMA')
avg_maintain_op = averager.apply(tensors)
for idx, c in enumerate(tensors):
name = re.sub('tower[p0-9]+/', '', c.op.name)
tf.scalar_summary(name, averager.average(c))
return avg_maintain_op
......
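For context, here is a minimal usage sketch of the refactored summary_moving_average, assuming the TF-0.x-era API this repo targets and the module path implied by the imports elsewhere in this commit; the cost tensor is made up for illustration:

import tensorflow as tf
from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average

# a made-up scalar cost tensor
cost = tf.reduce_mean(tf.square(tf.random_normal([8])), name='cost')

add_moving_summary(cost)             # registers it in MOVING_SUMMARY_VARS_KEY
ema_op = summary_moving_average()    # default: summarize the whole collection
# or, as the multi-GPU trainers can now do, pass tensors explicitly:
# ema_op = summary_moving_average(tensors=[cost])
# ema_op is then grouped into train_op so the averages update every step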
......@@ -13,7 +13,7 @@ from ..utils.naming import *
from .common import get_op_tensor_name
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
'get_savename_from_varname', 'is_training_specific_name']
'get_savename_from_varname', 'is_training_name']
def get_savename_from_varname(
varname, varname_prefix=None,
......@@ -97,7 +97,7 @@ def dump_chkpt_vars(model_path):
result[n] = reader.get_tensor(n)
return result
def is_training_specific_name(name):
def is_training_name(name):
"""
This is only used to improve logging.
:returns: guess whether this tensor is something only used in training.
......
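The body of is_training_name is outside this hunk. Purely as a hypothetical illustration of the docstring's intent (guessing which variables exist only for training, so that missing-from-checkpoint warnings can be suppressed for them), a heuristic might look like:

def is_training_name(name):
    # Hypothetical sketch only; the real implementation is not shown in this diff.
    # Optimizer slot variables and EMA shadow variables exist only during training,
    # so a checkpoint exported for inference will legitimately lack them.
    if name.endswith('/Adam') or name.endswith('/Adam_1'):
        return True
    if name.endswith('/Momentum'):
        return True
    if 'EMA_summary/' in name or name.endswith('/EMA'):
        return True
    return False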
......@@ -139,7 +139,7 @@ class Trainer(object):
if self.coord.should_stop():
return
self.run_step() # implemented by subclass
#callbacks.trigger_step() # not useful?
callbacks.trigger_step() # not useful?
self.trigger_epoch()
except StopTraining:
logger.info("Training was stopped.")
......
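Since the main loop now calls callbacks.trigger_step() on every iteration, per-step hooks become possible again. A toy, framework-free object (made-up names, not tensorpack's actual Callback API) illustrates the kind of hook this enables:

class StepCounter(object):
    # Toy per-step hook; purely illustrative.
    def __init__(self):
        self.count = 0

    def trigger_step(self):
        # would be called once per run_step() by the loop shown above
        self.count += 1

    def trigger_epoch(self):
        print("ran {} steps this epoch".format(self.count))
        self.count = 0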
......@@ -15,30 +15,26 @@ from ..tfutils import (backup_collection, restore_collection,
get_global_step_var, TowerContext)
from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
from .trainer import FeedlessTrainer
from .queue import QueueInputTrainer
from .trainer import FeedlessTrainer, SingleCostFeedlessTrainer
from .queue import QueueInputTrainer, QueueInputTrainerBase
__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
class MultiGPUTrainer(FeedlessTrainer):
""" Base class for multi-gpu training"""
def _multi_tower_grads(self):
logger.info("Training a model of {} tower".format(len(self.config.tower)))
@staticmethod
def _multi_tower_grads(towers, get_tower_grad_func):
logger.info("Training a model of {} tower".format(len(towers)))
grad_list = []
global_scope = tf.get_variable_scope()
for idx, t in enumerate(self.config.tower):
for idx, t in enumerate(towers):
with tf.device('/gpu:{}'.format(t)), \
tf.variable_scope(global_scope, reuse=idx > 0), \
TowerContext('tower{}'.format(idx)) as scope:
logger.info("Building graph for training tower {}...".format(idx))
model_inputs = self._get_input_tensors_noreuse()
self.model.build_graph(model_inputs)
cost_var = self.model.get_cost() # build tower
# TODO gate_gradients=0 might be faster?
grad_list.append(
self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))
grad_list.append(get_tower_grad_func())
if idx == 0:
add_moving_summary(cost_var)
......@@ -47,10 +43,12 @@ class MultiGPUTrainer(FeedlessTrainer):
restore_collection(backup)
return grad_list
class SyncMultiGPUTrainer(QueueInputTrainer, MultiGPUTrainer):
class SyncMultiGPUTrainer(QueueInputTrainerBase, MultiGPUTrainer, SingleCostFeedlessTrainer):
def __init__(self, config, input_queue=None, predict_tower=None):
super(MultiGPUTrainer, self).__init__(config, input_queue, predict_tower)
assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
super(SyncMultiGPUTrainer, self).__init__(config)
self._setup_predictor_factory(predict_tower)
self._build_enque_thread(input_queue)
@staticmethod
def _average_grads(tower_grads):
......@@ -75,18 +73,28 @@ class SyncMultiGPUTrainer(QueueInputTrainer, MultiGPUTrainer):
return ret
def _setup(self):
grad_list = self._multi_tower_grads()
grad_list = MultiGPUTrainer._multi_tower_grads(
self.config.tower, lambda: self._get_cost_and_grad()[1])
grads = SyncMultiGPUTrainer._average_grads(grad_list)
grads = apply_grad_processors(grads,
self.model.get_gradient_processor())
grads = apply_grad_processors(grads, self.model.get_gradient_processor())
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
summary_moving_average(), name='train_op')
class AsyncMultiGPUTrainer(QueueInputTrainer, MultiGPUTrainer):
def run_step(self):
self.sess.run(self.train_op)
class AsyncMultiGPUTrainer(QueueInputTrainerBase, MultiGPUTrainer, SingleCostFeedlessTrainer):
def __init__(self, config, input_queue=None, predict_tower=None):
assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
super(AsyncMultiGPUTrainer, self).__init__(config)
self._setup_predictor_factory(predict_tower)
self._build_enque_thread(input_queue)
def _setup(self):
grad_list = self._multi_tower_grads()
grad_list = MultiGPUTrainer._multi_tower_grads(
self.config.tower, lambda: self._get_cost_and_grad()[1])
gradprocs = self.model.get_gradient_processor()
# pretend to average the grads, in order to make async and
# sync have consistent effective learning rate
......
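The key refactor above is that _multi_tower_grads is now a @staticmethod parameterized by a grad-building callback, so the sync and async trainers (and anything built on SingleCostFeedlessTrainer) can share one tower loop. A simplified, framework-free sketch of the pattern, with made-up names:

def multi_tower_grads(towers, get_tower_grad_fn):
    # Loop over towers and delegate per-tower cost/grad construction to the callback.
    # The real code additionally enters a device scope, a reusing variable scope
    # (reuse=idx > 0) and a TowerContext around each call.
    grad_list = []
    for idx, t in enumerate(towers):
        grad_list.append(get_tower_grad_fn())
    return grad_list

# sync usage: average the per-tower gradients before applying them once
#   grad_list = multi_tower_grads(config.tower, lambda: trainer._get_cost_and_grad()[1])
#   grads = average_grads(grad_list)
# async usage: apply each tower's gradients independently, without averaging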
......@@ -13,7 +13,8 @@ from ..utils import logger
from ..callbacks.concurrency import StartProcOrThread
from ..tfutils.gradproc import apply_grad_processors
from .trainer import FeedlessTrainer, MultiPredictorTowerTrainer
from .trainer import (FeedlessTrainer, MultiPredictorTowerTrainer,
SingleCostFeedlessTrainer)
__all__ = ['QueueInputTrainerBase', 'QueueInputTrainer']
......@@ -88,7 +89,7 @@ class QueueInputTrainerBase(FeedlessTrainer):
#tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
return ret
class QueueInputTrainer(MultiPredictorTowerTrainer, QueueInputTrainerBase):
class QueueInputTrainer(MultiPredictorTowerTrainer, QueueInputTrainerBase, SingleCostFeedlessTrainer):
""" Single GPU Trainer, takes input from a queue"""
def __init__(self, config, input_queue=None, predict_tower=None):
......@@ -103,23 +104,12 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, QueueInputTrainerBase):
self._setup_predictor_factory(predict_tower)
self._build_enque_thread(input_queue)
def _single_tower_grad(self, actual_inputs):
""" Get grad and cost for single-tower"""
with TowerContext(''):
self.model.build_graph(actual_inputs)
cost_var = self.model.get_cost()
grads = self.config.optimizer.compute_gradients(
cost_var, gate_gradients=0) # GATE_NONE
add_moving_summary(cost_var)
return grads
def _setup(self):
assert len(self.config.tower) == 1, \
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
actual_inputs = self._get_input_tensors_noreuse()
grads = self._single_tower_grad(actual_inputs)
grads = apply_grad_processors(grads,
self.model.get_gradient_processor())
with TowerContext(''):
cost, grads = self._get_cost_and_grad()
grads = apply_grad_processors(grads, self.model.get_gradient_processor())
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
......
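Splitting QueueInputTrainerBase out of QueueInputTrainer is what lets the multi-GPU trainers reuse the queue plumbing without also inheriting the single-GPU _setup. A toy mixin sketch of the composition, with invented class and method names:

class QueuePlumbing(object):            # stands in for QueueInputTrainerBase
    def build_enqueue(self):
        print("start enqueue thread; dequeued tensors feed the graph")

class SingleCost(object):               # stands in for SingleCostFeedlessTrainer
    def cost_and_grad(self):
        return "cost", ["grads"]

class ToySingleGPUTrainer(QueuePlumbing, SingleCost):
    def setup(self):
        self.build_enqueue()
        cost, grads = self.cost_and_grad()                        # one tower

class ToyMultiGPUTrainer(QueuePlumbing, SingleCost):
    def setup(self):
        self.build_enqueue()
        grad_list = [self.cost_and_grad()[1] for _ in range(4)]   # one per tower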
......@@ -17,7 +17,8 @@ from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
from ..tfutils.gradproc import apply_grad_processors
__all__ = ['SimpleTrainer', 'FeedlessTrainer', 'MultiPredictorTowerTrainer']
__all__ = ['SimpleTrainer', 'FeedlessTrainer', 'MultiPredictorTowerTrainer',
'SingleCostFeedlessTrainer']
class PredictorFactory(object):
""" Make predictors for a trainer"""
......@@ -124,3 +125,16 @@ class FeedlessTrainer(Trainer):
""" return a list of actual input tensors.
Always return new tensors (for multi tower) if called multiple times.
"""
class SingleCostFeedlessTrainer(Trainer):
def _get_cost_and_grad(self):
""" get the cost and gradient on a new tower"""
actual_inputs = self._get_input_tensors_noreuse()
self.model.build_graph(actual_inputs)
cost_var = self.model.get_cost()
# GATE_NONE faster?
grads = self.config.optimizer.compute_gradients(
cost_var, gate_gradients=0)
add_moving_summary(cost_var)
return cost_var, grads
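On the '# GATE_NONE faster?' comment above: gate_gradients=0 corresponds to tf.train.Optimizer.GATE_NONE in this era of TensorFlow, which applies no gating between gradient computations and is typically the fastest option (at the cost of less deterministic ordering). A minimal TF-1.x-style illustration with a made-up variable and cost:

import tensorflow as tf

w = tf.Variable(3.0, name='w')
cost = tf.square(w - 1.0, name='cost')
opt = tf.train.GradientDescentOptimizer(0.1)
# equivalent to gate_gradients=0 as written in the trainer code above
grads = opt.compute_gradients(cost, gate_gradients=tf.train.Optimizer.GATE_NONE)
train_op = opt.apply_gradients(grads)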