Commit 206e1a67 authored by Yuxin Wu

some docs change

parent 843ab15c
@@ -153,7 +153,7 @@ class QueueInput(FeedfreeInput):
     def size(self):
         return self.ds.size()
 
-    # TODO XXX use input data mapping. not all placeholders are needed
+    # TODO use input data mapping. Not all placeholders are needed.
     def setup(self, model):
         self.input_placehdrs = model.get_reused_placehdrs()
         assert len(self.input_placehdrs) > 0, \
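For context, a minimal sketch of the queue-based input pattern that QueueInput implements, assuming TF1 graph mode; the placeholders and sizes below are illustrative, not tensorpack API:

import tensorflow as tf  # TF1 graph-mode API assumed

# Illustrative placeholders standing in for model.get_reused_placehdrs()
image = tf.placeholder(tf.float32, shape=[128, 28, 28], name='image')
label = tf.placeholder(tf.int32, shape=[128], name='label')
placehdrs = [image, label]

# The queue stores whole datapoints: a background thread repeatedly feeds
# the placeholders from the DataFlow and runs enqueue_op, while the
# training graph consumes tensors from dequeue() instead of a feed_dict.
queue = tf.FIFOQueue(
    capacity=50,
    dtypes=[p.dtype for p in placehdrs],
    shapes=[p.get_shape() for p in placehdrs])
enqueue_op = queue.enqueue(placehdrs)
inputs = queue.dequeue()  # replaces the placeholders as model input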
@@ -335,7 +335,14 @@ class ZMQInput(FeedfreeInput):
 
 
 class StagingInputWrapper(FeedfreeInput):
+    """
+    A wrapper around a feedfree input, to prefetch it in StagingArea (usually on GPUs).
+    """
     class StagingCallback(Callback):
+        """
+        A callback registered by this input source, to make sure stage/unstage
+        is run at each step.
+        """
         def __init__(self, stage_op, unstage_op, nr_stage):
             self.nr_stage = nr_stage
             self.stage_op = stage_op
@@ -351,6 +358,12 @@ class StagingInputWrapper(FeedfreeInput):
             return self.fetches
 
     def __init__(self, input, devices, nr_stage=5):
+        """
+        Args:
+            input: a :class:`FeedfreeInput`
+            devices: list of devices to be used for each training tower
+            nr_stage: number of elements to prefetch
+        """
         self._input = input
         assert isinstance(input, FeedfreeInput)
         self._devices = devices
......
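As an aside, a minimal sketch of the StagingArea prefetch pattern this wrapper builds, assuming TF1's data_flow_ops module; make_staged_input is an illustrative helper, not part of tensorpack:

import tensorflow as tf
from tensorflow.python.ops import data_flow_ops  # TF1 internal module

def make_staged_input(tensors, device='/gpu:0'):
    # Put `tensors` into a StagingArea pinned to `device` and return
    # (stage_op, outputs). Before training, stage_op is run nr_stage times
    # to fill the pipeline; afterwards the StagingCallback runs stage and
    # unstage together each step, so the next datapoint is already resident
    # on the GPU when a step begins.
    with tf.device(device):
        area = data_flow_ops.StagingArea(
            dtypes=[t.dtype for t in tensors],
            shapes=[t.get_shape() for t in tensors])
        stage_op = area.put(tensors)
        outputs = area.get()  # the "unstaged" tensors fed to this tower
    return stage_op, outputs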
@@ -20,10 +20,11 @@ from .base import Trainer
 from .feedfree import SingleCostFeedfreeTrainer
 from .input_source import QueueInput, StagingInputWrapper
 
-__all__ = ['SyncMultiGPUTrainer', 'AsyncMultiGPUTrainer']
+__all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
+           'AsyncMultiGPUTrainer', 'LeastLoadedDeviceSetter']
 
 
-class MultiGPUTrainer(Trainer):
+class MultiGPUTrainerBase(Trainer):
     """Base class for multi-GPU training."""
     @staticmethod
     def build_on_multi_tower(towers, func, devices=None):
@@ -32,6 +33,9 @@ class MultiGPUTrainer(Trainer):
             towers: a list of relative GPU ids
             func: a lambda to be called inside each tower
             devices: a list of devices to be used. Defaults to the GPUs in ``towers``.
+
+        Returns:
+            List of outputs of ``func``, evaluated on each tower.
         """
         logger.info("Training a model of {} towers".format(len(towers)))
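A simplified sketch of what ``build_on_multi_tower`` does; the real implementation additionally manages variable reuse and tower contexts:

import tensorflow as tf

def build_on_multi_tower(towers, func, devices=None):
    # Evaluate `func` once per tower, each under its own device and name
    # scope; the per-tower return values (e.g. gradients) are collected.
    ret = []
    for idx, t in enumerate(towers):
        device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
        with tf.device(device), tf.name_scope('tower{}'.format(idx)):
            ret.append(func())
    return ret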
@@ -58,14 +62,13 @@
 
 
 # Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
-class ParamServerDeviceSetter(object):
-    """Helper class to assign variables on the least loaded ps-device."""
+class LeastLoadedDeviceSetter(object):
+    """Helper class to place variables on the least-loaded ps device."""
     def __init__(self, worker_device, ps_devices):
         """
         Args:
-            worker_device: the device to use for computer ops.
-            ps_devices: a list of device to use for Variable ops. Each variable is
-            assigned to the least loaded device.
+            worker_device: the device to use for compute ops.
+            ps_devices: a list of devices to use for Variable ops.
         """
         self.ps_devices = ps_devices
         self.worker_device = worker_device
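A simplified sketch of the device-setter protocol involved here: any callable passed to ``tf.device()`` receives each Operation and returns a device string, so variable placement can be load-balanced. The size bookkeeping below is illustrative:

import tensorflow as tf

class LeastLoadedDeviceSetterSketch(object):
    def __init__(self, worker_device, ps_devices):
        self.worker_device = worker_device
        self.ps_devices = ps_devices
        self.ps_sizes = [0] * len(ps_devices)  # elements placed per ps device

    def __call__(self, op):
        # Compute ops stay on the worker; only Variable ops go to a ps device.
        if op.type not in ['Variable', 'VariableV2']:
            return self.worker_device
        # Pick the ps device that currently holds the fewest elements.
        idx = self.ps_sizes.index(min(self.ps_sizes))
        num = op.outputs[0].get_shape().num_elements()
        self.ps_sizes[idx] += num if num is not None else 0
        return self.ps_devices[idx]

# usage: with tf.device(LeastLoadedDeviceSetterSketch('/gpu:0', raw_devices)): ...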
@@ -86,7 +89,7 @@ class ParamServerDeviceSetter(object):
         return device_name
 
 
-class SyncMultiGPUTrainerParameterServer(MultiGPUTrainer, SingleCostFeedfreeTrainer):
+class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
     """
     A multi-tower multi-GPU trainer which synchronizes the gradients computed
     from each tower, averages them, and updates the variables stored on the PS.
@@ -148,12 +151,12 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainer, SingleCostFeedfreeTrainer):
 
         raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
         if self._ps_device == 'gpu':
-            devices = [ParamServerDeviceSetter(d, raw_devices) for d in raw_devices]
+            devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
         else:
             devices = [tf.train.replica_device_setter(
                 worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]
 
-        grad_list = MultiGPUTrainer.build_on_multi_tower(
+        grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower, lambda: self._get_cost_and_grad()[1], devices)
 
         # debug tower performance (without update):
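A minimal sketch of the synchronous gradient averaging this trainer performs over ``grad_list``; average_grads is an illustrative helper, and the per-tower (gradient, variable) pairs align because the towers share variables:

import tensorflow as tf

def average_grads(grad_list):
    # grad_list[i] holds tower i's (gradient, variable) pairs; zip aligns
    # the pairs for the same variable across towers.
    averaged = []
    for grads_and_var in zip(*grad_list):
        grads = [g for g, _ in grads_and_var]
        v = grads_and_var[0][1]
        averaged.append((tf.add_n(grads) / float(len(grads)), v))
    return averaged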
@@ -175,7 +178,7 @@ def SyncMultiGPUTrainer(config):
     return SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')
 
 
-class AsyncMultiGPUTrainer(MultiGPUTrainer,
+class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
                            SingleCostFeedfreeTrainer):
     """
     A multi-tower multi-GPU trainer where each tower independently
@@ -204,7 +207,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
     def _setup(self):
         super(AsyncMultiGPUTrainer, self)._setup()
 
-        grad_list = MultiGPUTrainer.build_on_multi_tower(
+        grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower, lambda: self._get_cost_and_grad()[1])
         grad_list = [FilterNoneGrad().process(gv) for gv in grad_list]
         if self._scale_gradient and self.config.nr_tower > 1:
......
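For the gradient-scaling branch above, an illustrative one-liner: each tower's gradients are scaled by 1/nr_tower, which (per tensorpack's docs for this option) keeps the effective learning rate of asynchronous updates the same as the single-tower setting. The helper name is hypothetical:

def scale_grads(grads_and_vars, nr_tower):
    # Scale each (gradient, variable) pair by 1/nr_tower so async updates
    # keep the same effective learning rate as a single tower.
    return [(g * (1.0 / nr_tower), v) for g, v in grads_and_vars]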