Commit ad5cb725 authored by Yuxin Wu

remove feedfree

parent 82bf74c9
feedfree.py (deleted)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: feedfree.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

from ..utils import logger
from ..utils.develop import deprecated
from ..tfutils.tower import TowerContext
from ..graph_builder.input_source import QueueInput, FeedfreeInput

from .simple import SimpleTrainer
from .base import Trainer

__all__ = ['FeedfreeTrainerBase', 'SingleCostFeedfreeTrainer', 'QueueInputTrainer']


# TODO deprecate it some time
class FeedfreeTrainerBase(Trainer):
    """ A base trainer which runs iteration without feed_dict (therefore faster).
        Expect ``config.data`` to be a :class:`FeedfreeInput`.
    """

    @deprecated("Please build the graph yourself, e.g. by self.model.build_graph(self._input_source)")
    def build_train_tower(self):
        with TowerContext('', is_training=True):
            self.model.build_graph(self._input_source)

    def _setup(self):
        assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
        cbs = self._input_source.setup(self.model.get_inputs_desc())
        self.config.callbacks.extend(cbs)
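The docstring above captures the point of the feedfree trainers: the input tensors live inside the graph, so each training step runs without a feed_dict. As a rough illustration only (a minimal, self-contained sketch in plain TensorFlow 1.x, not tensorpack code), compare a fed step with a queue-fed step:

import numpy as np
import tensorflow as tf

# Fed version: data is copied from Python into the graph on every step.
x_feed = tf.placeholder(tf.float32, [None, 4])
w = tf.get_variable('w', [4, 1])
cost_feed = tf.reduce_mean(tf.square(tf.matmul(x_feed, w)))
train_feed = tf.train.GradientDescentOptimizer(0.1).minimize(cost_feed)

# Feed-free version: the input is dequeued inside the graph, so no feed_dict is needed.
queue = tf.FIFOQueue(10, dtypes=[tf.float32], shapes=[[4]])
enqueue = queue.enqueue_many(np.random.rand(10, 4).astype('float32'))
x_queue = queue.dequeue_many(5)
cost_queue = tf.reduce_mean(tf.square(tf.matmul(x_queue, w)))
train_queue = tf.train.GradientDescentOptimizer(0.1).minimize(cost_queue)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_feed, feed_dict={x_feed: np.random.rand(5, 4)})  # pays the feed cost each step
    sess.run(enqueue)
    sess.run(train_queue)  # no feed_dict: the graph pulls its own input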
class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
    """ A feedfree Trainer which assumes a single cost. """

    @deprecated("", "2017-11-21")
    def __init__(self, *args, **kwargs):
        super(SingleCostFeedfreeTrainer, self).__init__(*args, **kwargs)
        logger.warn("SingleCostFeedfreeTrainer was deprecated!")

    def _get_cost_and_grad(self):
        """ get the cost and gradient """
        self.model.build_graph(self._input_source)
        return self.model.get_cost_and_grad()
def QueueInputTrainer(config, input_queue=None):
    """
    A wrapper trainer which automatically wraps ``config.dataflow`` by a :class:`QueueInput`.
    It is an equivalent of ``SimpleTrainer(config)`` with ``config.data = QueueInput(dataflow)``.

    Args:
        config (TrainConfig): Must contain 'model' and 'dataflow'.
        input_queue (tf.QueueBase): an input queue. Defaults to the :class:`QueueInput` default.
    """
    assert (config.data is not None or config.dataflow is not None) and config.model is not None
    if config.data is not None:
        assert isinstance(config.data, QueueInput), config.data
    else:
        config.data = QueueInput(config.dataflow, input_queue)
        config.dataflow = None

    # debug
    # from tensorpack.train.input_source import StagingInputWrapper, DummyConstantInput
    # config.data = StagingInputWrapper(config.data, ['/gpu:0'])
    # config.data = DummyConstantInput([[128,224,224,3], [128]])
    return SimpleTrainer(config)
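For reference, a hedged usage sketch of the equivalence stated in the docstring above. MyModel and df are hypothetical placeholders for a user-defined ModelDesc and DataFlow, and the TrainConfig arguments assume the tensorpack API of this period:

from tensorpack import TrainConfig, SimpleTrainer, QueueInputTrainer
from tensorpack.graph_builder.input_source import QueueInput

config = TrainConfig(model=MyModel(), dataflow=df)  # MyModel / df: user-defined, hypothetical

# The wrapper call...
QueueInputTrainer(config).train()

# ...is documented as equivalent to wrapping the dataflow in a QueueInput yourself:
# config.data = QueueInput(df)
# config.dataflow = None
# SimpleTrainer(config).train()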
simple.py

@@ -6,7 +6,7 @@
 from .base import Trainer
 from ..utils import logger
-from ..graph_builder.input_source import FeedInput
+from ..graph_builder.input_source import FeedInput, QueueInput
 from ..graph_builder.training import SimpleGraphBuilder

 __all__ = ['SimpleTrainer']

@@ -39,29 +39,35 @@ class SimpleTrainer(Trainer):
                        "Consider QueueInput or other InputSource instead.")
         super(SimpleTrainer, self).__init__(config)

-    @staticmethod
-    def setup_graph(model, input):
-        """
-        Setup graph for SimpleTrainer. It simply build one tower and optimize `model.cost`.
-
-        Args:
-            model (ModelDesc):
-            input (InputSource):
-
-        Returns:
-            tf.Operation: the training op
-
-            [Callback]: the callbacks to be added
-        """
-        cbs = input.setup(model.get_inputs_desc())
-
-        def get_cost(*inputs):
-            model.build_graph(inputs)
-            return model.get_cost()
-        train_op = SimpleGraphBuilder().build(input, get_cost, model.get_optimizer)
-        return train_op, cbs
-
-    def _setup(self):
-        self.train_op, callbacks = SimpleTrainer.setup_graph(self.model, self._input_source)
-        self.config.callbacks.extend(callbacks)
+    def _setup(self):
+        cbs = self._input_source.setup(self.model.get_inputs_desc())
+
+        def get_cost(*inputs):
+            self.model.build_graph(inputs)
+            return self.model.get_cost()
+        self.train_op = SimpleGraphBuilder().build(self._input_source, get_cost, self.model.get_optimizer)
+        self.config.callbacks.extend(cbs)
+
+
+def QueueInputTrainer(config, input_queue=None):
+    """
+    A wrapper trainer which automatically wraps ``config.dataflow`` by a :class:`QueueInput`.
+    It is an equivalent of ``SimpleTrainer(config)`` with ``config.data = QueueInput(dataflow)``.
+
+    Args:
+        config (TrainConfig): Must contain 'model' and 'dataflow'.
+        input_queue (tf.QueueBase): an input queue. Defaults to the :class:`QueueInput` default.
+    """
+    assert (config.data is not None or config.dataflow is not None) and config.model is not None
+    if config.data is not None:
+        assert isinstance(config.data, QueueInput), config.data
+    else:
+        config.data = QueueInput(config.dataflow, input_queue)
+        config.dataflow = None
+    # debug
+    # from tensorpack.train.input_source import StagingInputWrapper, DummyConstantInput
+    # config.data = StagingInputWrapper(config.data, ['/gpu:0'])
+    # config.data = DummyConstantInput([[128,224,224,3], [128]])
+    return SimpleTrainer(config)
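To make the new _setup easier to follow, here is a hedged conceptual sketch of what the single-cost, single-tower build step amounts to. This is not the SimpleGraphBuilder source; it only assumes the InputSource's get_input_tensors() method and a TF1-style optimizer:

import tensorflow as tf

def build_single_cost_train_op(input_source, get_cost_fn, get_opt_fn):
    # Conceptual equivalent of SimpleGraphBuilder().build(input, get_cost, get_opt).
    inputs = input_source.get_input_tensors()   # tensors produced by the InputSource (no feed_dict)
    cost = get_cost_fn(*inputs)                  # build one tower, return the scalar cost
    opt = get_opt_fn()                           # e.g. the tf.train.Optimizer from the model
    return opt.minimize(cost, name='min_op')     # the train_op that _setup() stores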