Commit 206e1a67 authored by Yuxin Wu

some docs change

parent 843ab15c
@@ -153,7 +153,7 @@ class QueueInput(FeedfreeInput):
     def size(self):
         return self.ds.size()
 
-    # TODO XXX use input data mapping. not all placeholders are needed
+    # TODO use input data mapping. not all placeholders are needed
     def setup(self, model):
         self.input_placehdrs = model.get_reused_placehdrs()
         assert len(self.input_placehdrs) > 0, \
@@ -335,7 +335,14 @@ class ZMQInput(FeedfreeInput):
 
 class StagingInputWrapper(FeedfreeInput):
+    """
+    A wrapper around a feedfree input, to prefetch it in StagingArea (usually on GPUs).
+    """
     class StagingCallback(Callback):
+        """
+        A callback registered by this input source, to make sure stage/unstage
+        is run at each step.
+        """
         def __init__(self, stage_op, unstage_op, nr_stage):
             self.nr_stage = nr_stage
             self.stage_op = stage_op
@@ -351,6 +358,12 @@ class StagingInputWrapper(FeedfreeInput):
             return self.fetches
 
     def __init__(self, input, devices, nr_stage=5):
+        """
+        Args:
+            input: a :class:`FeedfreeInput`
+            devices: list of devices to be used for each training tower
+            nr_stage: number of elements to prefetch
+        """
         self._input = input
         assert isinstance(input, FeedfreeInput)
         self._devices = devices
...
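The new docstrings above describe StagingInputWrapper's job: keep a few batches staged on the GPU so compute never waits on input, with StagingCallback running the stage/unstage ops every step. Below is a minimal sketch of that StagingArea prefetch pattern, assuming TF 1.x; the function name and structure are illustrative, not tensorpack's actual implementation.

# Hedged sketch of the StagingArea prefetch idea (TF 1.x), for illustration only.
import tensorflow as tf
from tensorflow.python.ops import data_flow_ops

def stage_on_device(input_tensors, device='/gpu:0'):
    """Return (staged_tensors, stage_op): running `stage_op` copies the next
    batch onto `device`; `staged_tensors` is a batch staged on an earlier run."""
    with tf.device(device):
        area = data_flow_ops.StagingArea(
            dtypes=[t.dtype for t in input_tensors],
            shapes=[t.get_shape() for t in input_tensors])
        stage_op = area.put(input_tensors)    # enqueue the next batch onto the device
        staged_tensors = area.get()           # dequeue a previously staged batch
    return staged_tensors, stage_op

A callback in the spirit of StagingCallback would run stage_op nr_stage times before training starts (to fill the pipeline) and once more on every step, so a prefetched batch is always ready.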
@@ -20,10 +20,11 @@ from .base import Trainer
 from .feedfree import SingleCostFeedfreeTrainer
 from .input_source import QueueInput, StagingInputWrapper
 
-__all__ = ['SyncMultiGPUTrainer', 'AsyncMultiGPUTrainer']
+__all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
+           'AsyncMultiGPUTrainer', 'LeastLoadedDeviceSetter']
 
 
-class MultiGPUTrainer(Trainer):
+class MultiGPUTrainerBase(Trainer):
     """ Base class for multi-gpu training"""
     @staticmethod
     def build_on_multi_tower(towers, func, devices=None):
@@ -32,6 +33,9 @@ class MultiGPUTrainer(Trainer):
             towers: list of gpu relative ids
             func: a lambda to be called inside each tower
             devices: a list of devices to be used. By default will use GPUs in towers.
+
+        Returns:
+            List of outputs of ``func``, evaluated on each tower.
         """
         logger.info("Training a model of {} tower".format(len(towers)))
@@ -58,14 +62,13 @@
 
 # Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
-class ParamServerDeviceSetter(object):
-    """Helper class to assign variables on the least loaded ps-device."""
+class LeastLoadedDeviceSetter(object):
+    """ Helper class to assign variables on the least loaded ps-device."""
     def __init__(self, worker_device, ps_devices):
         """
         Args:
-            worker_device: the device to use for computer ops.
-            ps_devices: a list of device to use for Variable ops. Each variable is
-            assigned to the least loaded device.
+            worker_device: the device to use for compute ops.
+            ps_devices: a list of device to use for Variable ops.
         """
         self.ps_devices = ps_devices
         self.worker_device = worker_device
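As the rename suggests, the setter places each new variable on whichever ps device currently holds the fewest elements, while compute ops stay on the worker device. A hedged sketch of that idea, adapted from the tf_cnn_benchmarks pattern referenced above; the op-type checks and size accounting in tensorpack's class may differ.

# Hedged sketch of least-loaded variable placement, illustrative only.
import tensorflow as tf

class LeastLoadedSetterSketch(object):
    """Device function: variables go to the ps device with the fewest
    elements placed so far; everything else stays on the worker device."""
    def __init__(self, worker_device, ps_devices):
        self.worker_device = worker_device
        self.ps_devices = ps_devices
        self.ps_sizes = [0] * len(ps_devices)    # elements placed on each ps device

    def __call__(self, op):
        if op.type not in ('Variable', 'VariableV2'):
            return self.worker_device
        idx = self.ps_sizes.index(min(self.ps_sizes))
        self.ps_sizes[idx] += op.outputs[0].get_shape().num_elements() or 0
        return self.ps_devices[idx]

# Used as a device function, e.g.:
#   with tf.device(LeastLoadedSetterSketch('/gpu:0', raw_devices)):
#       ...  # build one tower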
@@ -86,7 +89,7 @@ class ParamServerDeviceSetter(object):
         return device_name
 
 
-class SyncMultiGPUTrainerParameterServer(MultiGPUTrainer, SingleCostFeedfreeTrainer):
+class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
     """
     A multi-tower multi-GPU trainer which synchronizes the gradients computed
     from each tower, averages them, and updates the variables stored on PS.
@@ -148,12 +151,12 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainer, SingleCostFeedfreeTrainer):
         raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
         if self._ps_device == 'gpu':
-            devices = [ParamServerDeviceSetter(d, raw_devices) for d in raw_devices]
+            devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
         else:
             devices = [tf.train.replica_device_setter(
                 worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]
 
-        grad_list = MultiGPUTrainer.build_on_multi_tower(
+        grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower, lambda: self._get_cost_and_grad()[1], devices)
 
         # debug tower performance (without update):
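Here grad_list is one list of (gradient, variable) pairs per tower; the "synchronize and average" step described in the class docstring then reduces them to a single list before the optimizer update. A standard sketch of that reduction, assuming every tower returns the variables in the same order; this is the common pattern, not necessarily tensorpack's exact code.

# Hedged sketch of cross-tower gradient averaging.
import tensorflow as tf

def average_grads(tower_grads):
    """tower_grads: one [(grad, var), ...] list per tower, same variable order.
    Returns a single averaged [(grad, var), ...] list."""
    averaged = []
    for pairs in zip(*tower_grads):    # the same variable across all towers
        grads = [g for g, _ in pairs]
        averaged.append((tf.add_n(grads) / float(len(grads)), pairs[0][1]))
    return averaged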
@@ -175,7 +178,7 @@ def SyncMultiGPUTrainer(config):
     return SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')
 
 
-class AsyncMultiGPUTrainer(MultiGPUTrainer,
+class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
                            SingleCostFeedfreeTrainer):
     """
     A multi-tower multi-GPU trainer where each tower independently
@@ -204,7 +207,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
     def _setup(self):
         super(AsyncMultiGPUTrainer, self)._setup()
-        grad_list = MultiGPUTrainer.build_on_multi_tower(
+        grad_list = MultiGPUTrainerBase.build_on_multi_tower(
            self.config.tower, lambda: self._get_cost_and_grad()[1])
         grad_list = [FilterNoneGrad().process(gv) for gv in grad_list]
         if self._scale_gradient and self.config.nr_tower > 1:
...
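The last lines shown do the per-tower gradient cleanup for the async case: pairs whose gradient is None (variables not used by that tower's cost) are dropped via FilterNoneGrad, and when _scale_gradient is on with more than one tower, gradients are scaled by 1/nr_tower so the effective step size stays roughly comparable to single-GPU training. A rough equivalent for dense gradients, with illustrative names, not tensorpack's implementation:

def filter_none_grads(grads_and_vars):
    # Drop (None, var) pairs, mirroring what FilterNoneGrad is used for above.
    return [(g, v) for g, v in grads_and_vars if g is not None]

def scale_grads(grads_and_vars, nr_tower):
    # Dense gradients only; sparse IndexedSlices would need separate handling.
    return [(g / float(nr_tower), v) for g, v in grads_and_vars]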