Commit efe3dfb5 authored by Yuxin Wu

remove the many levels of Trainer hierarchy

parent 079eb3a9
@@ -8,7 +8,6 @@ import os
 from six.moves import range
 from ..utils import logger
-from .feedfree import SingleCostFeedfreeTrainer
 from .multigpu import MultiGPUTrainerBase
 from ..callbacks import RunOp
 from ..tfutils.sesscreate import NewSessionCreator
@@ -35,7 +34,7 @@ class OverrideToLocalVariable(object):
             return getter(name, *args, **kwargs)

-class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
+class DistributedReplicatedTrainer(MultiGPUTrainerBase):
     """
     Distributed replicated training.
     Each worker process builds the same model on one or more GPUs.
@@ -191,7 +190,8 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         # Ngpu * Nvar * 2
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower,
-            lambda: self._get_cost_and_grad()[1],
+            lambda: MultiGPUTrainerBase._build_graph_get_grads(
+                self.model, self._input_source),
             devices=self.raw_devices,
             var_strategy='replicated',
             vs_names=None)  # use the default vs names
...
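For reference, the new per-tower callable used above replaces the old self._get_cost_and_grad()[1]. Written out as a standalone function, it does roughly the following (a minimal sketch mirroring the _build_graph_get_grads helper this commit adds to MultiGPUTrainerBase; the function name here is illustrative only):

def tower_grads(model, input_source):
    # Build this tower's part of the graph from the given InputSource,
    # then keep only the gradients; get_cost_and_grad() returns (cost, grads).
    model.build_graph(input_source)
    return model.get_cost_and_grad()[1]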
@@ -20,10 +20,8 @@ class FeedfreeTrainerBase(Trainer):
     Expect ``self.data`` to be a :class:`FeedfreeInput`.
     """
-    # TODO deprecated
+    @deprecated("Please build the graph yourself, e.g. by self.model.build_graph(self._input_source)")
     def build_train_tower(self):
-        logger.warn("build_train_tower() was deprecated! Please build the graph "
-                    "yourself, e.g. by self.model.build_graph(self._input_source)")
         with TowerContext('', is_training=True):
             self.model.build_graph(self._input_source)
@@ -36,16 +34,20 @@ class FeedfreeTrainerBase(Trainer):
         self.hooked_sess.run(self.train_op)

-# TODO Kept for now for back-compat
+# deprecated
 class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
     """ A feedfree Trainer which assumes a single cost. """
+    def __init__(self, *args, **kwargs):
+        super(SingleCostFeedfreeTrainer, self).__init__(*args, **kwargs)
+        logger.warn("SingleCostFeedfreeTrainer was deprecated!")
     def _get_cost_and_grad(self):
         """ get the cost and gradient"""
         self.model.build_graph(self._input_source)
         return self.model.get_cost_and_grad()

-@deprecated("Use SimpleTrainer with config.data instead!")
+@deprecated("Use SimpleTrainer with config.data is the same!", "2017-09-13")
 def SimpleFeedfreeTrainer(config):
     assert isinstance(config.data, FeedfreeInput), config.data
     return SimpleTrainer(config)
@@ -53,9 +55,8 @@ def SimpleFeedfreeTrainer(config):
 def QueueInputTrainer(config, input_queue=None):
     """
-    A wrapper trainer which automatically wraps ``config.dataflow`` by a
-    :class:`QueueInput`.
-    It is an equivalent of ``SimpleFeedfreeTrainer(config)`` with ``config.data = QueueInput(dataflow)``.
+    A wrapper trainer which automatically wraps ``config.dataflow`` by a :class:`QueueInput`.
+    It is an equivalent of ``SimpleTrainer(config)`` with ``config.data = QueueInput(dataflow)``.

     Args:
         config (TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
...
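As the deprecation messages above say, the feedfree wrapper trainers reduce to SimpleTrainer plus an explicit input source. A minimal migration sketch, assuming the top-level tensorpack exports of that era (TrainConfig, SimpleTrainer, QueueInput) and user-supplied model and dataflow objects; none of this is taken verbatim from the commit:

# Migration sketch only; the import names and the TrainConfig(data=...) argument
# are assumed from the docstrings above.
from tensorpack import TrainConfig, SimpleTrainer, QueueInput

def make_trainer(model, dataflow):
    # Old: QueueInputTrainer(TrainConfig(model=model, dataflow=dataflow))
    # New: wrap the dataflow in QueueInput yourself and pass it as config.data.
    config = TrainConfig(model=model, data=QueueInput(dataflow))
    return SimpleTrainer(config)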
@@ -15,9 +15,8 @@ from ..tfutils.collection import backup_collection, restore_collection
 from ..tfutils.gradproc import ScaleGradient
 from ..callbacks.graph import RunOp
-from .base import Trainer
-from .feedfree import SingleCostFeedfreeTrainer
 from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput
+from .feedfree import FeedfreeTrainerBase

 __all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
            'AsyncMultiGPUTrainer', 'LeastLoadedDeviceSetter',
@@ -44,7 +43,7 @@ def apply_prefetch_policy(config, gpu_prefetch=True):
         config.data = StagingInputWrapper(config.data, devices)

-class MultiGPUTrainerBase(Trainer):
+class MultiGPUTrainerBase(FeedfreeTrainerBase):
     """ Base class for multi-gpu training"""
     @staticmethod
     def build_on_multi_tower(
@@ -116,6 +115,11 @@ class MultiGPUTrainerBase(Trainer):
         nvars = [len(k) for k in grad_list]
         assert len(set(nvars)) == 1, "Number of gradients from each tower is different! " + str(nvars)
+    @staticmethod
+    def _build_graph_get_grads(model, input):
+        model.build_graph(input)
+        return model.get_cost_and_grad()[1]

 # Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
 class LeastLoadedDeviceSetter(object):
@@ -148,7 +152,7 @@ class LeastLoadedDeviceSetter(object):
         return sanitize_name(device_name)

-class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
+class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
     """
     A data-parallel Multi-GPU trainer which synchronoizes the gradients computed
     from each tower, averages them and update to variables stored across all
@@ -199,7 +203,9 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
                 worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
-            self.config.tower, lambda: self._get_cost_and_grad()[1], devices)
+            self.config.tower,
+            lambda: MultiGPUTrainerBase._build_graph_get_grads(
+                self.model, self._input_source), devices)
         MultiGPUTrainerBase._check_grad_list(grad_list)

         # debug tower performance (without update):
@@ -223,7 +229,7 @@ def SyncMultiGPUTrainer(config):
     return SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')

-class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
+class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
     """
     Data-parallel Multi-GPU trainer where each GPU contains a replicate of the
     whole model. Each gradient update is broadcast and synced.
@@ -266,7 +272,8 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower,
-            lambda: self._get_cost_and_grad()[1],
+            lambda: MultiGPUTrainerBase._build_graph_get_grads(
+                self.model, self._input_source),
             var_strategy='replicated',
             # use no variable scope for the first tower
             vs_names=[''] + [None] * (self.config.nr_tower - 1))
@@ -308,7 +315,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         return tf.group(*post_init_ops, name='sync_variables_from_tower0')

-class AsyncMultiGPUTrainer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
+class AsyncMultiGPUTrainer(MultiGPUTrainerBase):
     """
     A multi-tower multi-GPU trainer where each tower independently
     asynchronously updates the model without averaging the gradient.
@@ -330,7 +337,9 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
         devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
-            self.config.tower, lambda: self._get_cost_and_grad()[1], devices)
+            self.config.tower,
+            lambda: MultiGPUTrainerBase._build_graph_get_grads(
+                self.model, self._input_source), devices)
         MultiGPUTrainerBase._check_grad_list(grad_list)

         if self._scale_gradient and self.config.nr_tower > 1:
...
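After this commit every multi-GPU trainer in this file subclasses MultiGPUTrainerBase alone and obtains per-tower gradients through the new static helper instead of the SingleCostFeedfreeTrainer mixin. A sketch of that shared pattern, modeled on the AsyncMultiGPUTrainer hunk above (the module path and the _setup method name are assumptions, not taken from this diff):

# Illustrative subclass only; mirrors the calling pattern shown in the hunks above.
from tensorpack.train.multigpu import MultiGPUTrainerBase, LeastLoadedDeviceSetter

class MySingleCostTrainer(MultiGPUTrainerBase):  # hypothetical trainer
    def _setup(self):
        raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
        devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
        grad_list = MultiGPUTrainerBase.build_on_multi_tower(
            self.config.tower,
            # per-tower callable: build the graph, keep only the gradients
            lambda: MultiGPUTrainerBase._build_graph_get_grads(
                self.model, self._input_source),
            devices)
        MultiGPUTrainerBase._check_grad_list(grad_list)
        # grad_list is then averaged (sync) or applied per tower (async)
        # to form the training op, as in the trainers above.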