Commit a349c558 authored by Yuxin Wu's avatar Yuxin Wu

small internal rename

parent 6c68f8aa
......@@ -12,7 +12,7 @@ from ..utils import logger, get_tqdm
from ..dataflow import DataFlow
from ..tfutils.common import get_op_tensor_name
from ..tfutils import TowerContext
from ..train.input_data import FeedfreeInput
from ..train.input_data import TensorInput
from ..predict import PredictorTowerBuilder
from .base import Triggerable
......@@ -161,7 +161,7 @@ class FeedfreeInferenceRunner(Triggerable):
prefix(str): an prefix used to build the tower. Must be set
differently if more than one :class:`FeedfreeInferenceRunner` are used.
"""
assert isinstance(input, FeedfreeInput), input
assert isinstance(input, TensorInput), input
self._input_data = input
if not isinstance(infs, list):
self.infs = [infs]
......@@ -192,7 +192,7 @@ class FeedfreeInferenceRunner(Triggerable):
self._find_output_tensors()
def _find_input_tensors(self):
self._input_data._setup(self.trainer)
self._input_data.setup(self.trainer.model)
# only 1 prediction tower will be used for inference
self._input_tensors = self._input_data.get_input_tensors()
model_placehdrs = self.trainer.model.get_reused_placehdrs()
......
......@@ -36,7 +36,7 @@ class FeedfreeTrainerBase(Trainer):
def _setup(self):
assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
self._input_method._setup(self)
self._input_method.setup_training(self)
def run_step(self):
""" Simply run ``self.train_op``."""
......
......@@ -22,8 +22,13 @@ __all__ = ['InputData', 'FeedfreeInput',
@six.add_metaclass(ABCMeta)
class InputData(object):
""" Base class for the abstract InputData. """
def setup(self, model):
pass
def setup_training(self, trainer):
self.setup(trainer.model)
class FeedInput(InputData):
""" Input by iterating over a DataFlow and feed datapoints. """
......@@ -38,8 +43,8 @@ class FeedInput(InputData):
def size(self):
return self.ds.size()
def _setup(self, trainer):
self.input_placehdrs = trainer.model.get_reused_placehdrs()
def setup(self, model):
self.input_placehdrs = model.get_reused_placehdrs()
rds = RepeatedData(self.ds, -1)
rds.reset_state()
self.data_producer = rds.get_data()
......@@ -58,18 +63,16 @@ class FeedfreeInput(InputData):
""" Abstract base for input without feed,
e.g. by queue or other operations. """
@abstractmethod
def get_input_tensors(self):
"""
Returns:
list: A list of tensors corresponding to the inputs of the model.
Always create and return a list of new input tensors when called.
"""
return self._get_input_tensors()
@abstractmethod
def _get_input_tensors(self):
"""
always create and return a list of new input tensors
"""
def get_client_threads(self):
return []
class EnqueueThread(ShareSessionThread):
......@@ -125,18 +128,21 @@ class QueueInput(FeedfreeInput):
return self.ds.size()
# TODO XXX use input data mapping. not all placeholders are needed
def _setup(self, trainer):
self.input_placehdrs = trainer.model.get_reused_placehdrs()
def setup(self, model):
self.input_placehdrs = model.get_reused_placehdrs()
assert len(self.input_placehdrs) > 0, \
"QueueInput can only be used with input placeholders!"
"QueueInput has to be used with input placeholders!"
if self.queue is None:
self.queue = tf.FIFOQueue(
50, [x.dtype for x in self.input_placehdrs],
name='input_queue')
self.thread = EnqueueThread(self.queue, self.ds, self.input_placehdrs)
def setup_training(self, trainer):
self.setup(trainer.model)
trainer.config.callbacks.append(StartProcOrThread(self.thread))
def _get_input_tensors(self):
def get_input_tensors(self):
ret = self.queue.dequeue(name='input_deque')
if isinstance(ret, tf.Tensor): # only one input
ret = [ret]
......@@ -166,10 +172,10 @@ class BatchQueueInput(FeedfreeInput):
def size(self):
return self.ds.size() // self.batch_size
def _setup(self, trainer):
self.input_placehdrs = trainer.model.get_reused_placehdrs()
def setup(self, model):
self.input_placehdrs = model.get_reused_placehdrs()
assert len(self.input_placehdrs) > 0, \
"QueueInput can only be used with input placeholders!"
"BatchQueueInput has to be used with input placeholders!"
# prepare placeholders without the first dimension
placehdrs_nobatch = []
......@@ -195,9 +201,12 @@ class BatchQueueInput(FeedfreeInput):
assert shp.is_fully_defined(), shape_err
self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch)
def setup_training(self, trainer):
self.setup(trainer.model)
trainer.config.callbacks.append(StartProcOrThread(self.thread))
def _get_input_tensors(self):
def get_input_tensors(self):
ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
if isinstance(ret, tf.Tensor): # only one input
ret = [ret]
......@@ -221,7 +230,7 @@ class DummyConstantInput(FeedfreeInput):
self.shapes = shapes
logger.warn("Using dummy input for debug!")
def _get_input_tensors(self):
def get_input_tensors(self):
placehdrs = self.input_placehdrs
assert len(self.shapes) == len(placehdrs)
ret = []
......@@ -253,8 +262,5 @@ class TensorInput(FeedfreeInput):
raise NotImplementedError("size of TensorInput is undefined!")
return self._size
def _setup(self, trainer):
pass
def _get_input_tensors(self):
def get_input_tensors(self):
return self.get_tensor_fn()
......@@ -32,13 +32,12 @@ class SimpleTrainer(Trainer):
self.hooked_sess.run(self.train_op, feed_dict=feed)
def _setup(self):
self._input_method._setup(self)
self._input_method.setup_training(self)
model = self.model
self.input_vars = model.get_reused_placehdrs()
self.inputs = model.get_reused_placehdrs()
with TowerContext('', is_training=True):
model.build_graph(self.input_vars)
model.build_graph(self.inputs)
cost_var = model.get_cost()
opt = self.config.optimizer
grads = opt.compute_gradients(cost_var)
self.train_op = opt.apply_gradients(grads, name='min_op')
self.train_op = opt.minimize(cost_var, name='min_op')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment