Commit e6493857 authored by Yuxin Wu

StagingInputWrapper takes a list of int

parent 4c5cdf9b
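
The gist of the change: StagingInputWrapper now accepts the tower list (GPU ids as ints) directly and builds the '/gpu:{}' device strings itself. A minimal sketch of old vs. new call sites, assuming a config object whose tower attribute is a list of ints such as [0, 1]:

    # Before: callers built device strings themselves.
    devices = ['/gpu:{}'.format(k) for k in config.tower]
    input = StagingInputWrapper(QueueInput(config.dataflow), devices)

    # After: pass the GPU ids; the wrapper builds the device strings internally.
    input = StagingInputWrapper(QueueInput(config.dataflow), config.tower)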
@@ -136,7 +136,7 @@ class MultiGPUGANTrainer(Trainer):
         raw_devices = ['/gpu:{}'.format(k) for k in config.tower]
         # setup input
-        input = StagingInputWrapper(QueueInput(config.dataflow), raw_devices)
+        input = StagingInputWrapper(QueueInput(config.dataflow), config.tower)
         model = config.model
         cbs = input.setup(model.get_inputs_desc())
         config.callbacks.extend(cbs)
@@ -2,6 +2,7 @@
 # File: __init__.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

+import os as _os
 from tensorpack.libinfo import __version__, _HAS_TF
@@ -15,7 +16,11 @@ if _HAS_TF:
     from tensorpack.callbacks import *
     from tensorpack.tfutils import *
-    from tensorpack.train import *
+    # In development. Default to v1
+    if _os.environ.get('TENSORPACK_TRAIN_API', 'v1') == 'v2':
+        from tensorpack.trainv2 import *
+    else:
+        from tensorpack.train import *
     from tensorpack.graph_builder import *
     from tensorpack.input_source import *
     from tensorpack.predict import *
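
The new TENSORPACK_TRAIN_API switch is read once at import time, so it must be set before tensorpack is imported. A minimal sketch (the variable name and its 'v1'/'v2' values come from the diff above):

    import os
    # Opt in to the in-development v2 training API; any other value
    # (or leaving the variable unset) keeps the default v1 API.
    os.environ['TENSORPACK_TRAIN_API'] = 'v2'

    import tensorpack  # now re-exports tensorpack.trainv2 instead of tensorpack.train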
@@ -19,6 +19,7 @@ from ..tfutils.common import get_op_tensor_name
 from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
 from ..utils.concurrency import ShareSessionThread
+from ..utils.develop import log_deprecated
 from ..callbacks.base import Callback
 from ..callbacks.graph import RunOp
@@ -457,7 +458,8 @@ class TFDatasetInput(FeedfreeInput):

 class StagingInputWrapper(FeedfreeInput):
     """
-    A wrapper around a feedfree input, to prefetch it in StagingArea (usually on GPUs).
+    A wrapper around a feedfree input,
+    to prefetch the input in StagingArea (on GPUs).
     """
     class StagingCallback(Callback):
         """
@@ -478,16 +480,22 @@ class StagingInputWrapper(FeedfreeInput):
         def _before_run(self, ctx):
             return self.fetches

-    def __init__(self, input, devices, nr_stage=5):
+    def __init__(self, input, towers, nr_stage=5):
         """
         Args:
-            input: a :class:`FeedfreeInput`
-            devices: list of devices to be used for each training tower
-            nr_stage: number of elements to prefetch
+            input (FeedfreeInput):
+            towers ([int]): list of GPU ids to prefetch on.
+            nr_stage: number of elements to prefetch on each GPU.
         """
         assert isinstance(input, FeedfreeInput), input
         self._input = input
-        self._devices = devices
+        if not isinstance(towers[0], int):
+            # API changed
+            log_deprecated("StagingInputWrapper(devices=)", "Use (towers=) instead!", "2018-01-31")
+            self._devices = towers
+        else:
+            self._devices = ['/gpu:{}'.format(k) for k in towers]
         self._nr_stage = nr_stage
         self._areas = []
         self._stage_ops = []
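
Both call conventions now reach the same internal state; a minimal sketch of the two paths through the compatibility check above (StagingInputWrapper and QueueInput are from the diff; the dataflow df is a hypothetical placeholder):

    # New API: pass GPU ids as ints; device strings are built internally.
    input = StagingInputWrapper(QueueInput(df), towers=[0, 1])
    # -> self._devices == ['/gpu:0', '/gpu:1']

    # Old API: device strings still work, but log_deprecated() warns that
    # the devices= form will be dropped after 2018-01-31.
    input = StagingInputWrapper(QueueInput(df), ['/gpu:0', '/gpu:1'])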
@@ -44,8 +44,7 @@ def apply_prefetch_policy(config, gpu_prefetch=True):
         # seem to only improve on >1 GPUs
         if not isinstance(config.data, (StagingInputWrapper, DummyConstantInput)):
-            devices = ['/gpu:{}'.format(k) for k in config.tower]
-            config.data = StagingInputWrapper(config.data, devices)
+            config.data = StagingInputWrapper(config.data, config.tower)

 class SyncMultiGPUTrainerParameterServer(Trainer):
@@ -62,9 +62,4 @@ def QueueInputTrainer(config, input_queue=None):
     else:
        config.data = QueueInput(config.dataflow, input_queue)
     config.dataflow = None
-    # debug
-    # from tensorpack.train.input_source import StagingInputWrapper, DummyConstantInput
-    # config.data = StagingInputWrapper(config.data, ['/gpu:0'])
-    # config.data = DummyConstantInput([[128,224,224,3], [128]])
     return SimpleTrainer(config)