Commit efe3dfb5 authored by Yuxin Wu

remove the many levels of Trainer hierarchy

parent 079eb3a9
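
For orientation, here is a rough, hypothetical sketch (not part of the diff) of how the trainer class hierarchy changes in this commit, inferred from the hunks below:

```python
# Illustrative stubs only, not tensorpack's real class bodies.

class Trainer: ...
class FeedfreeTrainerBase(Trainer): ...

# Before: concrete trainers mixed in SingleCostFeedfreeTrainer to get
# _get_cost_and_grad(), e.g.
#   class SingleCostFeedfreeTrainer(FeedfreeTrainerBase): ...
#   class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer): ...
#   class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer): ...

# After: MultiGPUTrainerBase itself derives from FeedfreeTrainerBase and exposes a
# static _build_graph_get_grads() helper, so the extra mixin level disappears:
class MultiGPUTrainerBase(FeedfreeTrainerBase): ...
class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase): ...
class AsyncMultiGPUTrainer(MultiGPUTrainerBase): ...
class DistributedReplicatedTrainer(MultiGPUTrainerBase): ...
```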
......@@ -8,7 +8,6 @@ import os
from six.moves import range
from ..utils import logger
from .feedfree import SingleCostFeedfreeTrainer
from .multigpu import MultiGPUTrainerBase
from ..callbacks import RunOp
from ..tfutils.sesscreate import NewSessionCreator
......@@ -35,7 +34,7 @@ class OverrideToLocalVariable(object):
return getter(name, *args, **kwargs)
class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
class DistributedReplicatedTrainer(MultiGPUTrainerBase):
"""
Distributed replicated training.
Each worker process builds the same model on one or more GPUs.
......@@ -191,7 +190,8 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
# Ngpu * Nvar * 2
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower,
lambda: self._get_cost_and_grad()[1],
lambda: MultiGPUTrainerBase._build_graph_get_grads(
self.model, self._input_source),
devices=self.raw_devices,
var_strategy='replicated',
vs_names=None) # use the default vs names
......
......@@ -20,10 +20,8 @@ class FeedfreeTrainerBase(Trainer):
Expect ``self.data`` to be a :class:`FeedfreeInput`.
"""
# TODO deprecated
@deprecated("Please build the graph yourself, e.g. by self.model.build_graph(self._input_source)")
def build_train_tower(self):
logger.warn("build_train_tower() was deprecated! Please build the graph "
"yourself, e.g. by self.model.build_graph(self._input_source)")
with TowerContext('', is_training=True):
self.model.build_graph(self._input_source)
......@@ -36,16 +34,20 @@ class FeedfreeTrainerBase(Trainer):
self.hooked_sess.run(self.train_op)
# TODO Kept for now for back-compat
# deprecated
class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
""" A feedfree Trainer which assumes a single cost. """
def __init__(self, *args, **kwargs):
super(SingleCostFeedfreeTrainer, self).__init__(*args, **kwargs)
logger.warn("SingleCostFeedfreeTrainer was deprecated!")
def _get_cost_and_grad(self):
""" get the cost and gradient"""
self.model.build_graph(self._input_source)
return self.model.get_cost_and_grad()
@deprecated("Use SimpleTrainer with config.data instead!")
@deprecated("Use SimpleTrainer with config.data is the same!", "2017-09-13")
def SimpleFeedfreeTrainer(config):
assert isinstance(config.data, FeedfreeInput), config.data
return SimpleTrainer(config)
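
The `@deprecated(...)` usages above take a message and, after this commit, an end-of-life date string. Below is a minimal, hypothetical sketch of such a decorator for illustration only; tensorpack's own `deprecated` utility lives in its utils package and may differ in detail.

```python
import functools
import warnings

def deprecated(text="", eol=""):
    """Hypothetical minimal (message, end-of-life date) deprecation decorator."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            msg = "{}() was deprecated".format(func.__name__)
            if eol:
                msg += " and may be removed after {}".format(eol)
            if text:
                msg += ". " + text
            warnings.warn(msg, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)
        return wrapper
    return decorator
```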
......@@ -53,9 +55,8 @@ def SimpleFeedfreeTrainer(config):
def QueueInputTrainer(config, input_queue=None):
"""
A wrapper trainer which automatically wraps ``config.dataflow`` by a
:class:`QueueInput`.
It is an equivalent of ``SimpleFeedfreeTrainer(config)`` with ``config.data = QueueInput(dataflow)``.
A wrapper trainer which automatically wraps ``config.dataflow`` by a :class:`QueueInput`.
It is an equivalent of ``SimpleTrainer(config)`` with ``config.data = QueueInput(dataflow)``.
Args:
config (TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
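
Per the updated docstring, `QueueInputTrainer(config)` is now just `SimpleTrainer` with the dataflow wrapped in a `QueueInput`. A hedged sketch of that equivalence (assuming the top-level tensorpack exports; the real wrapper also handles the optional `input_queue` argument):

```python
from tensorpack import QueueInput, SimpleTrainer  # assumed top-level exports

def queue_input_trainer(config):
    # roughly what QueueInputTrainer(config) amounts to after this commit
    config.data = QueueInput(config.dataflow)   # wrap the dataflow in a queue-based input
    config.dataflow = None                      # the dataflow is now owned by config.data
    return SimpleTrainer(config)
```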
......
......@@ -15,9 +15,8 @@ from ..tfutils.collection import backup_collection, restore_collection
from ..tfutils.gradproc import ScaleGradient
from ..callbacks.graph import RunOp
from .base import Trainer
from .feedfree import SingleCostFeedfreeTrainer
from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput
from .feedfree import FeedfreeTrainerBase
__all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
'AsyncMultiGPUTrainer', 'LeastLoadedDeviceSetter',
......@@ -44,7 +43,7 @@ def apply_prefetch_policy(config, gpu_prefetch=True):
config.data = StagingInputWrapper(config.data, devices)
class MultiGPUTrainerBase(Trainer):
class MultiGPUTrainerBase(FeedfreeTrainerBase):
""" Base class for multi-gpu training"""
@staticmethod
def build_on_multi_tower(
......@@ -116,6 +115,11 @@ class MultiGPUTrainerBase(Trainer):
nvars = [len(k) for k in grad_list]
assert len(set(nvars)) == 1, "Number of gradients from each tower is different! " + str(nvars)
@staticmethod
def _build_graph_get_grads(model, input):
model.build_graph(input)
return model.get_cost_and_grad()[1]
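
The new static helper above replaces the `self._get_cost_and_grad()[1]` calls that previously came from `SingleCostFeedfreeTrainer`. A small, self-contained sketch (with a hypothetical stub model, not tensorpack's `ModelDesc`) of how `build_on_multi_tower` is expected to use it, once per tower:

```python
class _StubModel:
    """Hypothetical stand-in for a tensorpack model, for illustration only."""
    def build_graph(self, input_source):
        self.input_source = input_source            # stands in for building the tower graph
    def get_cost_and_grad(self):
        return 0.0, [("dW", "W"), ("db", "b")]      # (cost, [(grad, var), ...])

def _build_graph_get_grads(model, input):
    # same shape as the helper added in this commit
    model.build_graph(input)
    return model.get_cost_and_grad()[1]

model, input_source = _StubModel(), "queue_input"
tower_func = lambda: _build_graph_get_grads(model, input_source)

# build_on_multi_tower calls the zero-argument tower function once per tower,
# yielding a grad_list of shape Ngpu x Nvar x 2 (grad, var):
grad_list = [tower_func() for _ in range(2)]
nvars = [len(k) for k in grad_list]
assert len(set(nvars)) == 1                        # mirrors _check_grad_list above
```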
# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
class LeastLoadedDeviceSetter(object):
......@@ -148,7 +152,7 @@ class LeastLoadedDeviceSetter(object):
return sanitize_name(device_name)
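
For context, `LeastLoadedDeviceSetter` places each newly created variable on whichever parameter device currently holds the fewest bytes. A simplified, hypothetical sketch of the idea (the real class operates on TF ops and device names):

```python
class LeastLoadedSetterSketch:
    """Illustrative only: pick the device with the least accumulated variable size."""
    def __init__(self, devices):
        self.sizes = {d: 0 for d in devices}

    def place(self, nbytes):
        device = min(self.sizes, key=self.sizes.get)  # least-loaded device so far
        self.sizes[device] += nbytes
        return device

setter = LeastLoadedSetterSketch(['/gpu:0', '/gpu:1'])
print([setter.place(n) for n in (400, 300, 100, 50)])  # ['/gpu:0', '/gpu:1', '/gpu:1', '/gpu:0']
```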
class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
"""
A data-parallel Multi-GPU trainer which synchronizes the gradients computed
from each tower, averages them and updates the variables stored across all
......@@ -199,7 +203,9 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfree
worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower, lambda: self._get_cost_and_grad()[1], devices)
self.config.tower,
lambda: MultiGPUTrainerBase._build_graph_get_grads(
self.model, self._input_source), devices)
MultiGPUTrainerBase._check_grad_list(grad_list)
# debug tower performance (without update):
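
The docstring above describes the parameter-server trainer as averaging the per-tower gradients before applying them. A hedged sketch of that standard cross-tower averaging step (not necessarily tensorpack's exact code):

```python
import tensorflow as tf

def average_grads(grad_list):
    """grad_list: one list of (grad, var) pairs per tower, same variable order."""
    averaged = []
    for grad_and_vars in zip(*grad_list):              # group the same variable across towers
        grads = [g for g, _ in grad_and_vars]
        avg = tf.multiply(tf.add_n(grads), 1.0 / len(grads))
        averaged.append((avg, grad_and_vars[0][1]))    # every tower shares this variable
    return averaged
```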
......@@ -223,7 +229,7 @@ def SyncMultiGPUTrainer(config):
return SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')
class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
"""
Data-parallel Multi-GPU trainer where each GPU contains a replica of the
whole model. Each gradient update is broadcast and synced.
......@@ -266,7 +272,8 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower,
lambda: self._get_cost_and_grad()[1],
lambda: MultiGPUTrainerBase._build_graph_get_grads(
self.model, self._input_source),
var_strategy='replicated',
# use no variable scope for the first tower
vs_names=[''] + [None] * (self.config.nr_tower - 1))
......@@ -308,7 +315,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
return tf.group(*post_init_ops, name='sync_variables_from_tower0')
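
The `sync_variables_from_tower0` op above copies tower-0's variables into every other tower's replica after initialization. A hedged sketch of that broadcast, assuming variables are matched by stripping a hypothetical `towerN/` prefix (tensorpack's real matching logic differs in detail):

```python
import tensorflow as tf

def get_post_init_ops(all_vars):
    """all_vars: dict name -> tf.Variable; tower-0 variables carry no 'towerN/' prefix."""
    post_init_ops = []
    for name, v in all_vars.items():
        if name.startswith('tower'):
            tower0_name = name.split('/', 1)[1]                 # e.g. 'tower1/conv/W' -> 'conv/W'
            post_init_ops.append(v.assign(all_vars[tower0_name].read_value()))
    return tf.group(*post_init_ops, name='sync_variables_from_tower0')
```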
class AsyncMultiGPUTrainer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
class AsyncMultiGPUTrainer(MultiGPUTrainerBase):
"""
A multi-tower multi-GPU trainer where each tower independently and
asynchronously updates the model without averaging the gradients.
......@@ -330,7 +337,9 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower, lambda: self._get_cost_and_grad()[1], devices)
self.config.tower,
lambda: MultiGPUTrainerBase._build_graph_get_grads(
self.model, self._input_source), devices)
MultiGPUTrainerBase._check_grad_list(grad_list)
if self._scale_gradient and self.config.nr_tower > 1:
......
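
The truncated hunk above gates a `ScaleGradient` step on `nr_tower > 1`. In async training each tower applies its own gradients, so gradients are often scaled by `1/nr_tower` to keep the effective step size comparable to the synchronous case. A hedged sketch (an illustrative stand-in, not tensorpack's `ScaleGradient` processor):

```python
import tensorflow as tf

def scale_grads(grad_and_vars, nr_tower):
    """Scale every gradient in a list of (grad, var) pairs by 1/nr_tower."""
    return [(tf.multiply(g, 1.0 / nr_tower), v) for g, v in grad_and_vars]
```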