Commit 079eb3a9 authored by Yuxin Wu's avatar Yuxin Wu

Make SimpleTrainer support InputSource

parent 20ee19bc
......@@ -173,4 +173,4 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
SimpleFeedfreeTrainer(config).train()
SimpleTrainer(config).train()
......@@ -4,9 +4,11 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ..utils import logger
from ..utils.develop import deprecated
from ..tfutils.tower import TowerContext
from ..graph_builder.input_source import QueueInput, FeedfreeInput
from .simple import SimpleTrainer
from .base import Trainer
__all__ = ['FeedfreeTrainerBase', 'SingleCostFeedfreeTrainer',
......@@ -34,6 +36,7 @@ class FeedfreeTrainerBase(Trainer):
self.hooked_sess.run(self.train_op)
# TODO Kept for now for back-compat
class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
""" A feedfree Trainer which assumes a single cost. """
def _get_cost_and_grad(self):
......@@ -42,30 +45,10 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
return self.model.get_cost_and_grad()
class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer):
"""
A trainer with single cost, single training tower, any number of
prediction tower, and feed-free input.
"""
def __init__(self, config):
"""
Args:
config (TrainConfig): ``config.data`` must exist and is a :class:`FeedfreeInput`.
"""
self._input_source = config.data
assert isinstance(self._input_source, FeedfreeInput), self._input_source
super(SimpleFeedfreeTrainer, self).__init__(config)
assert len(self.config.tower) == 1, \
"Got nr_tower={}, but doesn't support multigpu!" \
" Use Sync/AsyncMultiGPUTrainer instead.".format(len(self.config.tower))
def _setup(self):
super(SimpleFeedfreeTrainer, self)._setup()
with TowerContext('', is_training=True):
cost, grads = self._get_cost_and_grad()
opt = self.model.get_optimizer()
self.train_op = opt.apply_gradients(grads, name='min_op')
@deprecated("Use SimpleTrainer with config.data instead!")
def SimpleFeedfreeTrainer(config):
    """
    Deprecated backward-compatibility shim.

    Validates that the config carries a feed-free input source and then
    delegates to :class:`SimpleTrainer`, which now handles InputSource
    directly.

    Args:
        config (TrainConfig): ``config.data`` must be a
            :class:`FeedfreeInput` instance.

    Returns:
        SimpleTrainer: a trainer constructed from ``config``.
    """
    # Keep the original contract: reject configs whose data source is not
    # feed-free before handing off to the replacement trainer.
    source = config.data
    assert isinstance(source, FeedfreeInput), source
    return SimpleTrainer(config)
def QueueInputTrainer(config, input_queue=None):
......@@ -76,16 +59,16 @@ def QueueInputTrainer(config, input_queue=None):
Args:
config (TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
input_queue (tf.QueueBase): an input queue. Defaults to the
:class:`QueueInput` default.
input_queue (tf.QueueBase): an input queue. Defaults to the :class:`QueueInput` default.
"""
if config.data is not None:
assert isinstance(config.data, QueueInput), config.data
else:
config.data = QueueInput(config.dataflow, input_queue)
config.dataflow = None
# debug
# from tensorpack.train.input_source import StagingInputWrapper, DummyConstantInput
# config.data = StagingInputWrapper(config.data, ['/gpu:0'])
# config.data = DummyConstantInput([[128,224,224,3], [128]])
return SimpleFeedfreeTrainer(config)
return SimpleTrainer(config)
......@@ -13,8 +13,10 @@ __all__ = ['SimpleTrainer']
class SimpleTrainer(Trainer):
""" A naive demo trainer which iterates over a DataFlow and feed into the
graph. It's not efficient compared to QueueInputTrainer or others."""
""" A naive single-tower single-cost demo trainer.
Support both InputSource and DataFlow.
When DataFlow is given, the InputSource to be used will be ``FeedInput(df)``.
"""
def __init__(self, config):
"""
......@@ -22,23 +24,25 @@ class SimpleTrainer(Trainer):
config (TrainConfig): the training config.
"""
super(SimpleTrainer, self).__init__(config)
assert len(self.config.tower) == 1, \
"Got nr_tower={}, but doesn't support multigpu!" \
" Use Sync/AsyncMultiGPUTrainer instead.".format(len(self.config.tower))
if config.dataflow is None:
self._input_source = config.data
assert isinstance(self._input_source, FeedInput), type(self._input_source)
else:
self._input_source = FeedInput(config.dataflow)
logger.warn("SimpleTrainer is slow! Do you really want to use it?")
logger.warn("FeedInput is slow (and this is the default of SimpleTrainer). "
"Consider QueueInput or other InputSource instead.")
def run_step(self):
""" Feed data into the graph and run the updates. """
self.hooked_sess.run(self.train_op)
def _setup(self):
self._setup_input_source(self._input_source)
with TowerContext('', is_training=True):
self.model.build_graph(self._input_source)
cost_var = self.model.get_cost()
cost, grads = self.model.get_cost_and_grad()
opt = self.model.get_optimizer()
self.train_op = opt.minimize(cost_var, name='min_op')
self.train_op = opt.apply_gradients(grads, name='min_op')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment