Commit a349c558 authored by Yuxin Wu

small internal rename

parent 6c68f8aa
...@@ -12,7 +12,7 @@ from ..utils import logger, get_tqdm ...@@ -12,7 +12,7 @@ from ..utils import logger, get_tqdm
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..tfutils import TowerContext from ..tfutils import TowerContext
from ..train.input_data import FeedfreeInput from ..train.input_data import TensorInput
from ..predict import PredictorTowerBuilder from ..predict import PredictorTowerBuilder
from .base import Triggerable from .base import Triggerable
...@@ -161,7 +161,7 @@ class FeedfreeInferenceRunner(Triggerable): ...@@ -161,7 +161,7 @@ class FeedfreeInferenceRunner(Triggerable):
prefix(str): an prefix used to build the tower. Must be set prefix(str): an prefix used to build the tower. Must be set
differently if more than one :class:`FeedfreeInferenceRunner` are used. differently if more than one :class:`FeedfreeInferenceRunner` are used.
""" """
assert isinstance(input, FeedfreeInput), input assert isinstance(input, TensorInput), input
self._input_data = input self._input_data = input
if not isinstance(infs, list): if not isinstance(infs, list):
self.infs = [infs] self.infs = [infs]
...@@ -192,7 +192,7 @@ class FeedfreeInferenceRunner(Triggerable): ...@@ -192,7 +192,7 @@ class FeedfreeInferenceRunner(Triggerable):
self._find_output_tensors() self._find_output_tensors()
def _find_input_tensors(self): def _find_input_tensors(self):
self._input_data._setup(self.trainer) self._input_data.setup(self.trainer.model)
# only 1 prediction tower will be used for inference # only 1 prediction tower will be used for inference
self._input_tensors = self._input_data.get_input_tensors() self._input_tensors = self._input_data.get_input_tensors()
model_placehdrs = self.trainer.model.get_reused_placehdrs() model_placehdrs = self.trainer.model.get_reused_placehdrs()
......
...@@ -36,7 +36,7 @@ class FeedfreeTrainerBase(Trainer): ...@@ -36,7 +36,7 @@ class FeedfreeTrainerBase(Trainer):
def _setup(self): def _setup(self):
assert isinstance(self._input_method, FeedfreeInput), type(self._input_method) assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
self._input_method._setup(self) self._input_method.setup_training(self)
def run_step(self): def run_step(self):
""" Simply run ``self.train_op``.""" """ Simply run ``self.train_op``."""
......
...@@ -22,7 +22,12 @@ __all__ = ['InputData', 'FeedfreeInput', ...@@ -22,7 +22,12 @@ __all__ = ['InputData', 'FeedfreeInput',
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class InputData(object): class InputData(object):
""" Base class for the abstract InputData. """ """ Base class for the abstract InputData. """
pass
def setup(self, model):
pass
def setup_training(self, trainer):
self.setup(trainer.model)
class FeedInput(InputData): class FeedInput(InputData):
...@@ -38,8 +43,8 @@ class FeedInput(InputData): ...@@ -38,8 +43,8 @@ class FeedInput(InputData):
def size(self): def size(self):
return self.ds.size() return self.ds.size()
def _setup(self, trainer): def setup(self, model):
self.input_placehdrs = trainer.model.get_reused_placehdrs() self.input_placehdrs = model.get_reused_placehdrs()
rds = RepeatedData(self.ds, -1) rds = RepeatedData(self.ds, -1)
rds.reset_state() rds.reset_state()
self.data_producer = rds.get_data() self.data_producer = rds.get_data()
...@@ -58,18 +63,16 @@ class FeedfreeInput(InputData): ...@@ -58,18 +63,16 @@ class FeedfreeInput(InputData):
""" Abstract base for input without feed, """ Abstract base for input without feed,
e.g. by queue or other operations. """ e.g. by queue or other operations. """
@abstractmethod
def get_input_tensors(self): def get_input_tensors(self):
""" """
Returns: Returns:
list: A list of tensors corresponding to the inputs of the model. list: A list of tensors corresponding to the inputs of the model.
Always create and return a list of new input tensors when called.
""" """
return self._get_input_tensors()
@abstractmethod def get_client_threads(self):
def _get_input_tensors(self): return []
"""
always create and return a list of new input tensors
"""
class EnqueueThread(ShareSessionThread): class EnqueueThread(ShareSessionThread):
...@@ -125,18 +128,21 @@ class QueueInput(FeedfreeInput): ...@@ -125,18 +128,21 @@ class QueueInput(FeedfreeInput):
return self.ds.size() return self.ds.size()
# TODO XXX use input data mapping. not all placeholders are needed # TODO XXX use input data mapping. not all placeholders are needed
def _setup(self, trainer): def setup(self, model):
self.input_placehdrs = trainer.model.get_reused_placehdrs() self.input_placehdrs = model.get_reused_placehdrs()
assert len(self.input_placehdrs) > 0, \ assert len(self.input_placehdrs) > 0, \
"QueueInput can only be used with input placeholders!" "QueueInput has to be used with input placeholders!"
if self.queue is None: if self.queue is None:
self.queue = tf.FIFOQueue( self.queue = tf.FIFOQueue(
50, [x.dtype for x in self.input_placehdrs], 50, [x.dtype for x in self.input_placehdrs],
name='input_queue') name='input_queue')
self.thread = EnqueueThread(self.queue, self.ds, self.input_placehdrs) self.thread = EnqueueThread(self.queue, self.ds, self.input_placehdrs)
def setup_training(self, trainer):
self.setup(trainer.model)
trainer.config.callbacks.append(StartProcOrThread(self.thread)) trainer.config.callbacks.append(StartProcOrThread(self.thread))
def _get_input_tensors(self): def get_input_tensors(self):
ret = self.queue.dequeue(name='input_deque') ret = self.queue.dequeue(name='input_deque')
if isinstance(ret, tf.Tensor): # only one input if isinstance(ret, tf.Tensor): # only one input
ret = [ret] ret = [ret]
...@@ -166,10 +172,10 @@ class BatchQueueInput(FeedfreeInput): ...@@ -166,10 +172,10 @@ class BatchQueueInput(FeedfreeInput):
def size(self): def size(self):
return self.ds.size() // self.batch_size return self.ds.size() // self.batch_size
def _setup(self, trainer): def setup(self, model):
self.input_placehdrs = trainer.model.get_reused_placehdrs() self.input_placehdrs = model.get_reused_placehdrs()
assert len(self.input_placehdrs) > 0, \ assert len(self.input_placehdrs) > 0, \
"QueueInput can only be used with input placeholders!" "BatchQueueInput has to be used with input placeholders!"
# prepare placeholders without the first dimension # prepare placeholders without the first dimension
placehdrs_nobatch = [] placehdrs_nobatch = []
...@@ -195,9 +201,12 @@ class BatchQueueInput(FeedfreeInput): ...@@ -195,9 +201,12 @@ class BatchQueueInput(FeedfreeInput):
assert shp.is_fully_defined(), shape_err assert shp.is_fully_defined(), shape_err
self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch) self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch)
def setup_training(self, trainer):
self.setup(trainer.model)
trainer.config.callbacks.append(StartProcOrThread(self.thread)) trainer.config.callbacks.append(StartProcOrThread(self.thread))
def _get_input_tensors(self): def get_input_tensors(self):
ret = self.queue.dequeue_many(self.batch_size, name='input_deque') ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
if isinstance(ret, tf.Tensor): # only one input if isinstance(ret, tf.Tensor): # only one input
ret = [ret] ret = [ret]
...@@ -221,7 +230,7 @@ class DummyConstantInput(FeedfreeInput): ...@@ -221,7 +230,7 @@ class DummyConstantInput(FeedfreeInput):
self.shapes = shapes self.shapes = shapes
logger.warn("Using dummy input for debug!") logger.warn("Using dummy input for debug!")
def _get_input_tensors(self): def get_input_tensors(self):
placehdrs = self.input_placehdrs placehdrs = self.input_placehdrs
assert len(self.shapes) == len(placehdrs) assert len(self.shapes) == len(placehdrs)
ret = [] ret = []
...@@ -253,8 +262,5 @@ class TensorInput(FeedfreeInput): ...@@ -253,8 +262,5 @@ class TensorInput(FeedfreeInput):
raise NotImplementedError("size of TensorInput is undefined!") raise NotImplementedError("size of TensorInput is undefined!")
return self._size return self._size
def _setup(self, trainer): def get_input_tensors(self):
pass
def _get_input_tensors(self):
return self.get_tensor_fn() return self.get_tensor_fn()
...@@ -32,13 +32,12 @@ class SimpleTrainer(Trainer): ...@@ -32,13 +32,12 @@ class SimpleTrainer(Trainer):
self.hooked_sess.run(self.train_op, feed_dict=feed) self.hooked_sess.run(self.train_op, feed_dict=feed)
def _setup(self): def _setup(self):
self._input_method._setup(self) self._input_method.setup_training(self)
model = self.model model = self.model
self.input_vars = model.get_reused_placehdrs() self.inputs = model.get_reused_placehdrs()
with TowerContext('', is_training=True): with TowerContext('', is_training=True):
model.build_graph(self.input_vars) model.build_graph(self.inputs)
cost_var = model.get_cost() cost_var = model.get_cost()
opt = self.config.optimizer opt = self.config.optimizer
grads = opt.compute_gradients(cost_var) self.train_op = opt.minimize(cost_var, name='min_op')
self.train_op = opt.apply_gradients(grads, name='min_op')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment