Commit e6493857 authored by Yuxin Wu

StagingInputWrapper takes a list of ints (GPU ids) instead of device name strings

parent 4c5cdf9b
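
A minimal before/after sketch of the call-site change (`df` stands in for any DataFlow; the GPU ids are illustrative):

    # old API (still works for now, but logs a deprecation warning):
    input = StagingInputWrapper(QueueInput(df), ['/gpu:0', '/gpu:1'])
    # new API: a list of ints; device names are derived internally:
    input = StagingInputWrapper(QueueInput(df), [0, 1])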
@@ -136,7 +136,7 @@ class MultiGPUGANTrainer(Trainer):
         raw_devices = ['/gpu:{}'.format(k) for k in config.tower]
         # setup input
-        input = StagingInputWrapper(QueueInput(config.dataflow), raw_devices)
+        input = StagingInputWrapper(QueueInput(config.dataflow), config.tower)
         model = config.model
         cbs = input.setup(model.get_inputs_desc())
         config.callbacks.extend(cbs)
@@ -2,6 +2,7 @@
 # File: __init__.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
+import os as _os
 from tensorpack.libinfo import __version__, _HAS_TF
@@ -15,6 +16,10 @@ if _HAS_TF:
     from tensorpack.callbacks import *
     from tensorpack.tfutils import *
-    from tensorpack.train import *
+    # In development. Default to v1
+    if _os.environ.get('TENSORPACK_TRAIN_API', 'v1') == 'v2':
+        from tensorpack.trainv2 import *
+    else:
+        from tensorpack.train import *
     from tensorpack.graph_builder import *
     from tensorpack.input_source import *
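
A usage sketch for the new switch (note: the variable must be set before the first tensorpack import, since the check runs at import time):

    import os
    # Opt in to the in-development v2 training API; any other value,
    # or leaving the variable unset, keeps the v1 API.
    os.environ['TENSORPACK_TRAIN_API'] = 'v2'

    import tensorpack  # re-exports tensorpack.trainv2 instead of tensorpack.train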
@@ -19,6 +19,7 @@ from ..tfutils.common import get_op_tensor_name
 from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
 from ..utils.concurrency import ShareSessionThread
+from ..utils.develop import log_deprecated
 from ..callbacks.base import Callback
 from ..callbacks.graph import RunOp
@@ -457,7 +458,8 @@ class TFDatasetInput(FeedfreeInput):
 class StagingInputWrapper(FeedfreeInput):
     """
-    A wrapper around a feedfree input, to prefetch it in StagingArea (usually on GPUs).
+    A wrapper around a feedfree input,
+    to prefetch the input in StagingArea (on GPUs).
     """
     class StagingCallback(Callback):
         """
@@ -478,16 +480,22 @@ class StagingInputWrapper(FeedfreeInput):
         def _before_run(self, ctx):
             return self.fetches
-    def __init__(self, input, devices, nr_stage=5):
+    def __init__(self, input, towers, nr_stage=5):
         """
         Args:
-            input: a :class:`FeedfreeInput`
-            devices: list of devices to be used for each training tower
-            nr_stage: number of elements to prefetch
+            input (FeedfreeInput):
+            towers ([int]): list of GPU ids to prefetch on.
+            nr_stage: number of elements to prefetch on each GPU.
         """
         assert isinstance(input, FeedfreeInput), input
         self._input = input
-        self._devices = devices
+        if not isinstance(towers[0], int):
+            # API changed
+            log_deprecated("StagingInputWrapper(devices=)", "Use (towers=) instead!", "2018-01-31")
+            self._devices = towers
+        else:
+            self._devices = ['/gpu:{}'.format(k) for k in towers]
         self._nr_stage = nr_stage
         self._areas = []
         self._stage_ops = []
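
A self-contained sketch of the dispatch added above (`_normalize_devices` is a hypothetical helper for illustration, not part of tensorpack):

    def _normalize_devices(towers):
        # Deprecated call style: elements are already TF device names.
        if not isinstance(towers[0], int):
            return list(towers)
        # New call style: GPU ids (ints) are turned into device names.
        return ['/gpu:{}'.format(k) for k in towers]

    assert _normalize_devices([0, 1]) == ['/gpu:0', '/gpu:1']
    assert _normalize_devices(['/gpu:0', '/gpu:1']) == ['/gpu:0', '/gpu:1']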
@@ -44,8 +44,7 @@ def apply_prefetch_policy(config, gpu_prefetch=True):
         # seem to only improve on >1 GPUs
         if not isinstance(config.data, (StagingInputWrapper, DummyConstantInput)):
-            devices = ['/gpu:{}'.format(k) for k in config.tower]
-            config.data = StagingInputWrapper(config.data, devices)
+            config.data = StagingInputWrapper(config.data, config.tower)
 class SyncMultiGPUTrainerParameterServer(Trainer):
@@ -62,9 +62,4 @@ def QueueInputTrainer(config, input_queue=None):
     else:
         config.data = QueueInput(config.dataflow, input_queue)
     config.dataflow = None
-    # debug
-    # from tensorpack.train.input_source import StagingInputWrapper, DummyConstantInput
-    # config.data = StagingInputWrapper(config.data, ['/gpu:0'])
-    # config.data = DummyConstantInput([[128,224,224,3], [128]])
     return SimpleTrainer(config)
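
For context, QueueInputTrainer wraps the dataflow in a QueueInput and hands the config to SimpleTrainer; a minimal call sketch (the model and dataflow names are placeholders):

    config = TrainConfig(model=MyModel(), dataflow=my_dataflow)
    QueueInputTrainer(config).train()  # builds the QueueInput internally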