Commit 079eb3a9 authored by Yuxin Wu's avatar Yuxin Wu

Make SimpleTrainer support InputSource

parent 20ee19bc
...@@ -173,4 +173,4 @@ if __name__ == '__main__': ...@@ -173,4 +173,4 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
SimpleFeedfreeTrainer(config).train() SimpleTrainer(config).train()
...@@ -4,9 +4,11 @@ ...@@ -4,9 +4,11 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ..utils import logger from ..utils import logger
from ..utils.develop import deprecated
from ..tfutils.tower import TowerContext from ..tfutils.tower import TowerContext
from ..graph_builder.input_source import QueueInput, FeedfreeInput from ..graph_builder.input_source import QueueInput, FeedfreeInput
from .simple import SimpleTrainer
from .base import Trainer from .base import Trainer
__all__ = ['FeedfreeTrainerBase', 'SingleCostFeedfreeTrainer', __all__ = ['FeedfreeTrainerBase', 'SingleCostFeedfreeTrainer',
...@@ -34,6 +36,7 @@ class FeedfreeTrainerBase(Trainer): ...@@ -34,6 +36,7 @@ class FeedfreeTrainerBase(Trainer):
self.hooked_sess.run(self.train_op) self.hooked_sess.run(self.train_op)
# TODO Kept for now for back-compat
class SingleCostFeedfreeTrainer(FeedfreeTrainerBase): class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
""" A feedfree Trainer which assumes a single cost. """ """ A feedfree Trainer which assumes a single cost. """
def _get_cost_and_grad(self): def _get_cost_and_grad(self):
...@@ -42,30 +45,10 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase): ...@@ -42,30 +45,10 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
return self.model.get_cost_and_grad() return self.model.get_cost_and_grad()
class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer): @deprecated("Use SimpleTrainer with config.data instead!")
""" def SimpleFeedfreeTrainer(config):
A trainer with single cost, single training tower, any number of assert isinstance(config.data, FeedfreeInput), config.data
prediction tower, and feed-free input. return SimpleTrainer(config)
"""
def __init__(self, config):
"""
Args:
config (TrainConfig): ``config.data`` must exist and is a :class:`FeedfreeInput`.
"""
self._input_source = config.data
assert isinstance(self._input_source, FeedfreeInput), self._input_source
super(SimpleFeedfreeTrainer, self).__init__(config)
assert len(self.config.tower) == 1, \
"Got nr_tower={}, but doesn't support multigpu!" \
" Use Sync/AsyncMultiGPUTrainer instead.".format(len(self.config.tower))
def _setup(self):
super(SimpleFeedfreeTrainer, self)._setup()
with TowerContext('', is_training=True):
cost, grads = self._get_cost_and_grad()
opt = self.model.get_optimizer()
self.train_op = opt.apply_gradients(grads, name='min_op')
def QueueInputTrainer(config, input_queue=None): def QueueInputTrainer(config, input_queue=None):
...@@ -76,16 +59,16 @@ def QueueInputTrainer(config, input_queue=None): ...@@ -76,16 +59,16 @@ def QueueInputTrainer(config, input_queue=None):
Args: Args:
config (TrainConfig): a `TrainConfig` instance. config.dataflow must exist. config (TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
input_queue (tf.QueueBase): an input queue. Defaults to the input_queue (tf.QueueBase): an input queue. Defaults to the :class:`QueueInput` default.
:class:`QueueInput` default.
""" """
if config.data is not None: if config.data is not None:
assert isinstance(config.data, QueueInput), config.data assert isinstance(config.data, QueueInput), config.data
else: else:
config.data = QueueInput(config.dataflow, input_queue) config.data = QueueInput(config.dataflow, input_queue)
config.dataflow = None
# debug # debug
# from tensorpack.train.input_source import StagingInputWrapper, DummyConstantInput # from tensorpack.train.input_source import StagingInputWrapper, DummyConstantInput
# config.data = StagingInputWrapper(config.data, ['/gpu:0']) # config.data = StagingInputWrapper(config.data, ['/gpu:0'])
# config.data = DummyConstantInput([[128,224,224,3], [128]]) # config.data = DummyConstantInput([[128,224,224,3], [128]])
return SimpleFeedfreeTrainer(config) return SimpleTrainer(config)
...@@ -13,8 +13,10 @@ __all__ = ['SimpleTrainer'] ...@@ -13,8 +13,10 @@ __all__ = ['SimpleTrainer']
class SimpleTrainer(Trainer): class SimpleTrainer(Trainer):
""" A naive demo trainer which iterates over a DataFlow and feed into the """ A naive single-tower single-cost demo trainer.
graph. It's not efficient compared to QueueInputTrainer or others.""" Support both InputSource and DataFlow.
When DataFlow is given, the InputSource to be used will be ``FeedInput(df)``.
"""
def __init__(self, config): def __init__(self, config):
""" """
...@@ -22,23 +24,25 @@ class SimpleTrainer(Trainer): ...@@ -22,23 +24,25 @@ class SimpleTrainer(Trainer):
config (TrainConfig): the training config. config (TrainConfig): the training config.
""" """
super(SimpleTrainer, self).__init__(config) super(SimpleTrainer, self).__init__(config)
assert len(self.config.tower) == 1, \
"Got nr_tower={}, but doesn't support multigpu!" \
" Use Sync/AsyncMultiGPUTrainer instead.".format(len(self.config.tower))
if config.dataflow is None: if config.dataflow is None:
self._input_source = config.data self._input_source = config.data
assert isinstance(self._input_source, FeedInput), type(self._input_source)
else: else:
self._input_source = FeedInput(config.dataflow) self._input_source = FeedInput(config.dataflow)
logger.warn("SimpleTrainer is slow! Do you really want to use it?") logger.warn("FeedInput is slow (and this is the default of SimpleTrainer). "
"Consider QueueInput or other InputSource instead.")
def run_step(self): def run_step(self):
""" Feed data into the graph and run the updates. """
self.hooked_sess.run(self.train_op) self.hooked_sess.run(self.train_op)
def _setup(self): def _setup(self):
self._setup_input_source(self._input_source) self._setup_input_source(self._input_source)
with TowerContext('', is_training=True): with TowerContext('', is_training=True):
self.model.build_graph(self._input_source) self.model.build_graph(self._input_source)
cost_var = self.model.get_cost() cost, grads = self.model.get_cost_and_grad()
opt = self.model.get_optimizer() opt = self.model.get_optimizer()
self.train_op = opt.minimize(cost_var, name='min_op') self.train_op = opt.apply_gradients(grads, name='min_op')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment