Commit bd686aab authored by Yuxin Wu

deprecate _get_input_vars

parent bbaf8d12
...@@ -97,7 +97,7 @@ class InferenceRunner(Triggerable): ...@@ -97,7 +97,7 @@ class InferenceRunner(Triggerable):
def _find_input_tensors(self): def _find_input_tensors(self):
if self.input_tensors is None: if self.input_tensors is None:
input_vars = self.trainer.model.get_reuse_placehdrs() input_vars = self.trainer.model.get_reused_placehdrs()
# TODO even if it works here, sparse still is unavailable # TODO even if it works here, sparse still is unavailable
# because get_tensor_by_name doesn't work for sparse # because get_tensor_by_name doesn't work for sparse
...@@ -198,7 +198,7 @@ class FeedfreeInferenceRunner(Triggerable): ...@@ -198,7 +198,7 @@ class FeedfreeInferenceRunner(Triggerable):
self._input_data._setup(self.trainer) self._input_data._setup(self.trainer)
# only 1 prediction tower will be used for inference # only 1 prediction tower will be used for inference
self._input_tensors = self._input_data.get_input_tensors() self._input_tensors = self._input_data.get_input_tensors()
model_placehdrs = self.trainer.model.get_reuse_placehdrs() model_placehdrs = self.trainer.model.get_reused_placehdrs()
if self._input_names is not None: if self._input_names is not None:
raise NotImplementedError("Random code. Not tested.") raise NotImplementedError("Random code. Not tested.")
assert len(self._input_names) == len(self._input_tensors), \ assert len(self._input_names) == len(self._input_tensors), \
......
...@@ -66,7 +66,7 @@ class LMDBData(RNGDataFlow): ...@@ -66,7 +66,7 @@ class LMDBData(RNGDataFlow):
Args: Args:
lmdb_path (str): a directory or a file. lmdb_path (str): a directory or a file.
shuffle (bool): shuffle the keys or not. shuffle (bool): shuffle the keys or not.
keys (list of str or str): list of str as the keys, used only when shuffle is True. keys (list[str] or str): list of str as the keys, used only when shuffle is True.
It can also be a format string e.g. ``{:0>8d}`` which will be It can also be a format string e.g. ``{:0>8d}`` which will be
formatted with the indices from 0 to *total_size - 1*. formatted with the indices from 0 to *total_size - 1*.
......
...@@ -8,17 +8,17 @@ import tensorflow as tf ...@@ -8,17 +8,17 @@ import tensorflow as tf
import pickle import pickle
import six import six
from ..utils import logger, INPUT_VARS_KEY from ..utils import logger, INPUTS_KEY
from ..tfutils.gradproc import CheckGradient from ..tfutils.gradproc import CheckGradient
from ..tfutils.summary import add_moving_summary from ..tfutils.summary import add_moving_summary
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph'] __all__ = ['InputDesc', 'InputVar', 'ModelDesc', 'ModelFromMetaGraph']
# TODO "variable" is not the right name to use for input here.
# TODO "variable" is not a right name to use across this file.
class InputVar(object): class InputDesc(object):
""" Store metadata about input placeholders. """ """ Store metadata about input placeholders. """
def __init__(self, type, shape, name, sparse=False): def __init__(self, type, shape, name, sparse=False):
""" """
...@@ -41,13 +41,16 @@ class InputVar(object): ...@@ -41,13 +41,16 @@ class InputVar(object):
return pickle.loads(buf) return pickle.loads(buf)
InputVar = InputDesc
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class ModelDesc(object): class ModelDesc(object):
""" Base class for a model description """ """ Base class for a model description """
def get_input_vars(self): def get_reused_placehdrs(self):
""" """
Create or return (if already created) raw input TF placeholder vars in the graph. Create or return (if already created) raw input TF placeholders in the graph.
Returns: Returns:
list[tf.Tensor]: the list of input placeholders in the graph. list[tf.Tensor]: the list of input placeholders in the graph.
...@@ -58,20 +61,21 @@ class ModelDesc(object): ...@@ -58,20 +61,21 @@ class ModelDesc(object):
self.reuse_input_vars = ret self.reuse_input_vars = ret
return ret return ret
# alias def get_input_vars(self):
get_reuse_placehdrs = get_input_vars logger.warn("[Deprecated] get_input_vars() was renamed to get_reused_placehdrs()!")
return self.get_reused_placehdrs()
def build_placeholders(self, prefix=''): def build_placeholders(self, prefix=''):
""" """
For each InputVar, create new placeholders with optional prefix and For each input, create new placeholders with optional prefix and
return them. Useful when building new towers. return them. Useful when building new towers.
Returns: Returns:
list[tf.Tensor]: the list of built placeholders. list[tf.Tensor]: the list of built placeholders.
""" """
input_vars = self._get_input_vars() input_vars = self._get_inputs()
for v in input_vars: for v in input_vars:
tf.add_to_collection(INPUT_VARS_KEY, v.dumps()) tf.add_to_collection(INPUTS_KEY, v.dumps())
ret = [] ret = []
for v in input_vars: for v in input_vars:
placehdr_f = tf.placeholder if not v.sparse else tf.sparse_placeholder placehdr_f = tf.placeholder if not v.sparse else tf.sparse_placeholder
...@@ -80,20 +84,21 @@ class ModelDesc(object): ...@@ -80,20 +84,21 @@ class ModelDesc(object):
name=prefix + v.name)) name=prefix + v.name))
return ret return ret
def get_input_vars_desc(self): def get_inputs_desc(self):
""" """
Returns: Returns:
list[:class:`InputVar`]: list of the underlying :class:`InputVar`. list[:class:`InputDesc`]: list of the underlying :class:`InputDesc`.
""" """
return self._get_input_vars() return self._get_inputs()
def _get_input_vars(self): # keep backward compatibility def _get_inputs(self): # this is a better name than _get_input_vars
""" """
:returns: a list of InputVar :returns: a list of InputDesc
""" """
return self._get_inputs() logger.warn("[Deprecated] _get_input_vars() is renamed to _get_inputs()")
return self._get_input_vars()
def _get_inputs(self): # this is a better name than _get_input_vars def _get_input_vars(self): # keep backward compatibility
raise NotImplementedError() raise NotImplementedError()
def build_graph(self, model_inputs): def build_graph(self, model_inputs):
...@@ -102,7 +107,7 @@ class ModelDesc(object): ...@@ -102,7 +107,7 @@ class ModelDesc(object):
Args: Args:
model_inputs (list[tf.Tensor]): a list of inputs, corresponding to model_inputs (list[tf.Tensor]): a list of inputs, corresponding to
InputVars of this model. InputDesc of this model.
""" """
self._build_graph(model_inputs) self._build_graph(model_inputs)
...@@ -169,14 +174,14 @@ class ModelFromMetaGraph(ModelDesc): ...@@ -169,14 +174,14 @@ class ModelFromMetaGraph(ModelDesc):
""" """
tf.train.import_meta_graph(filename) tf.train.import_meta_graph(filename)
all_coll = tf.get_default_graph().get_all_collection_keys() all_coll = tf.get_default_graph().get_all_collection_keys()
for k in [INPUT_VARS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES, for k in [INPUTS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES,
tf.GraphKeys.GLOBAL_VARIABLES]: tf.GraphKeys.GLOBAL_VARIABLES]:
assert k in all_coll, \ assert k in all_coll, \
"Collection {} not found in metagraph!".format(k) "Collection {} not found in metagraph!".format(k)
def _get_inputs(self): def _get_inputs(self):
col = tf.get_collection(INPUT_VARS_KEY) col = tf.get_collection(INPUTS_KEY)
col = [InputVar.loads(v) for v in col] col = [InputDesc.loads(v) for v in col]
return col return col
def _build_graph(self, _, __): def _build_graph(self, _, __):
......
...@@ -123,7 +123,7 @@ class OfflinePredictor(OnlinePredictor): ...@@ -123,7 +123,7 @@ class OfflinePredictor(OnlinePredictor):
""" """
self.graph = tf.Graph() self.graph = tf.Graph()
with self.graph.as_default(): with self.graph.as_default():
input_placehdrs = config.model.get_input_vars() input_placehdrs = config.model.get_reused_placehdrs()
with TowerContext('', False): with TowerContext('', False):
config.model.build_graph(input_placehdrs) config.model.build_graph(input_placehdrs)
......
...@@ -47,7 +47,7 @@ class PredictConfig(object): ...@@ -47,7 +47,7 @@ class PredictConfig(object):
self.input_names = input_names self.input_names = input_names
if self.input_names is None: if self.input_names is None:
# neither options is set, assume all inputs # neither options is set, assume all inputs
raw_vars = self.model.get_input_vars_desc() raw_vars = self.model.get_inputs_desc()
self.input_names = [k.name for k in raw_vars] self.input_names = [k.name for k in raw_vars]
self.output_names = output_names self.output_names = output_names
assert_type(self.output_names, list) assert_type(self.output_names, list)
......
...@@ -28,7 +28,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor): ...@@ -28,7 +28,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
with self.graph.as_default(): with self.graph.as_default():
# TODO backup summary keys? # TODO backup summary keys?
def fn(_): def fn(_):
config.model.build_graph(config.model.get_input_vars()) config.model.build_graph(config.model.get_reused_placehdrs())
build_prediction_graph(fn, towers) build_prediction_graph(fn, towers)
self.sess = tf.Session(config=config.session_config) self.sess = tf.Session(config=config.session_config)
......
...@@ -39,14 +39,14 @@ class FeedInput(InputData): ...@@ -39,14 +39,14 @@ class FeedInput(InputData):
return self.ds.size() return self.ds.size()
def _setup(self, trainer): def _setup(self, trainer):
self.input_vars = trainer.model.get_input_vars() self.input_placehdrs = trainer.model.get_reused_placehdrs()
rds = RepeatedData(self.ds, -1) rds = RepeatedData(self.ds, -1)
rds.reset_state() rds.reset_state()
self.data_producer = rds.get_data() self.data_producer = rds.get_data()
def next_feed(self): def next_feed(self):
data = next(self.data_producer) data = next(self.data_producer)
feed = dict(zip(self.input_vars, data)) feed = dict(zip(self.input_placehdrs, data))
self._last_feed = feed self._last_feed = feed
return feed return feed
...@@ -134,7 +134,7 @@ class QueueInput(FeedfreeInput): ...@@ -134,7 +134,7 @@ class QueueInput(FeedfreeInput):
return self.ds.size() return self.ds.size()
def _setup(self, trainer): def _setup(self, trainer):
self.input_placehdrs = trainer.model.get_input_vars() self.input_placehdrs = trainer.model.get_reused_placehdrs()
assert len(self.input_placehdrs) > 0, \ assert len(self.input_placehdrs) > 0, \
"QueueInput can only be used with input placeholders!" "QueueInput can only be used with input placeholders!"
if self.queue is None: if self.queue is None:
...@@ -182,7 +182,7 @@ class BatchQueueInput(FeedfreeInput): ...@@ -182,7 +182,7 @@ class BatchQueueInput(FeedfreeInput):
return self.ds.size() // self.batch_size return self.ds.size() // self.batch_size
def _setup(self, trainer): def _setup(self, trainer):
self.input_placehdrs = trainer.model.get_input_vars() self.input_placehdrs = trainer.model.get_reused_placehdrs()
assert len(self.input_placehdrs) > 0, \ assert len(self.input_placehdrs) > 0, \
"QueueInput can only be used with input placeholders!" "QueueInput can only be used with input placeholders!"
...@@ -194,7 +194,7 @@ class BatchQueueInput(FeedfreeInput): ...@@ -194,7 +194,7 @@ class BatchQueueInput(FeedfreeInput):
name=get_op_tensor_name(p.name)[0] + '-nobatch')) name=get_op_tensor_name(p.name)[0] + '-nobatch'))
# dequeue_many requires fully-defined shapes # dequeue_many requires fully-defined shapes
shape_err = "Use of BatchQueueInput requires input variables to have fully-defined " shape_err = "Use of BatchQueueInput requires inputs to have fully-defined "
"shapes except for the batch dimension" "shapes except for the batch dimension"
shapes = [] shapes = []
for p in placehdrs_nobatch: for p in placehdrs_nobatch:
...@@ -226,7 +226,7 @@ class BatchQueueInput(FeedfreeInput): ...@@ -226,7 +226,7 @@ class BatchQueueInput(FeedfreeInput):
class DummyConstantInput(FeedfreeInput): class DummyConstantInput(FeedfreeInput):
""" Input some constant variables. Only for debugging performance issues """ """ Input some constant tensor. Only for debugging performance issues """
def __init__(self, shapes): def __init__(self, shapes):
self.shapes = shapes self.shapes = shapes
...@@ -238,9 +238,10 @@ class DummyConstantInput(FeedfreeInput): ...@@ -238,9 +238,10 @@ class DummyConstantInput(FeedfreeInput):
ret = [] ret = []
for idx, p in enumerate(placehdrs): for idx, p in enumerate(placehdrs):
with tf.device('/gpu:0'): with tf.device('/gpu:0'):
ret.append(tf.get_variable('dummy-' + p.op.name, ret.append(tf.get_variable(
shape=self.shapes[idx], dtype=p.dtype, trainable=False, 'dummy-' + p.op.name, shape=self.shapes[idx],
initializer=tf.constant_initializer())) dtype=p.dtype, trainable=False,
initializer=tf.constant_initializer()))
return ret return ret
......
...@@ -47,7 +47,7 @@ class PredictorFactory(object): ...@@ -47,7 +47,7 @@ class PredictorFactory(object):
freeze_collection(SUMMARY_BACKUP_KEYS), \ freeze_collection(SUMMARY_BACKUP_KEYS), \
tf.variable_scope(tf.get_variable_scope(), reuse=True): tf.variable_scope(tf.get_variable_scope(), reuse=True):
def fn(_): def fn(_):
self.model.build_graph(self.model.get_input_vars()) self.model.build_graph(self.model.get_reused_placehdrs())
build_prediction_graph(fn, self.towers) build_prediction_graph(fn, self.towers)
self.tower_built = True self.tower_built = True
...@@ -79,7 +79,7 @@ class SimpleTrainer(Trainer): ...@@ -79,7 +79,7 @@ class SimpleTrainer(Trainer):
def _setup(self): def _setup(self):
self._input_method._setup(self) self._input_method._setup(self)
model = self.model model = self.model
self.input_vars = model.get_input_vars() self.input_vars = model.get_reused_placehdrs()
with TowerContext('', is_training=True): with TowerContext('', is_training=True):
model.build_graph(self.input_vars) model.build_graph(self.input_vars)
cost_var = model.get_cost() cost_var = model.get_cost()
......
...@@ -19,8 +19,8 @@ PREDICT_TOWER = 'towerp' ...@@ -19,8 +19,8 @@ PREDICT_TOWER = 'towerp'
# extra variables to summarize during training in a moving-average way # extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES' MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
# placeholders for input variables # metainfo for input tensors
INPUT_VARS_KEY = 'INPUT_VARIABLES' INPUTS_KEY = 'INPUTS_METAINFO'
SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY] SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment