Commit bd686aab authored by Yuxin Wu

deprecate _get_input_vars

parent bbaf8d12
@@ -97,7 +97,7 @@ class InferenceRunner(Triggerable):
     def _find_input_tensors(self):
         if self.input_tensors is None:
-            input_vars = self.trainer.model.get_reuse_placehdrs()
+            input_vars = self.trainer.model.get_reused_placehdrs()
             # TODO even if it works here, sparse still is unavailable
             # because get_tensor_by_name doesn't work for sparse
@@ -198,7 +198,7 @@ class FeedfreeInferenceRunner(Triggerable):
         self._input_data._setup(self.trainer)
         # only 1 prediction tower will be used for inference
         self._input_tensors = self._input_data.get_input_tensors()
-        model_placehdrs = self.trainer.model.get_reuse_placehdrs()
+        model_placehdrs = self.trainer.model.get_reused_placehdrs()
         if self._input_names is not None:
             raise NotImplementedError("Random code. Not tested.")
             assert len(self._input_names) == len(self._input_tensors), \
@@ -66,7 +66,7 @@ class LMDBData(RNGDataFlow):
         Args:
             lmdb_path (str): a directory or a file.
             shuffle (bool): shuffle the keys or not.
-            keys (list of str or str): list of str as the keys, used only when shuffle is True.
+            keys (list[str] or str): list of str as the keys, used only when shuffle is True.
                 It can also be a format string e.g. ``{:0>8d}`` which will be
                 formatted with the indices from 0 to *total_size - 1*.
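The ``{:0>8d}`` pattern above is plain Python string formatting: pad with '0', align right, width 8. A minimal illustration of the keys it would generate (the indices are hypothetical):

    # Zero-pad each integer index to 8 characters, as fixed-width LMDB keys.
    for idx in range(3):
        print("{:0>8d}".format(idx))   # 00000000, 00000001, 00000002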
@@ -8,17 +8,17 @@ import tensorflow as tf
 import pickle
 import six
-from ..utils import logger, INPUT_VARS_KEY
+from ..utils import logger, INPUTS_KEY
 from ..tfutils.gradproc import CheckGradient
 from ..tfutils.summary import add_moving_summary
 from ..tfutils.tower import get_current_tower_context
-__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']
+__all__ = ['InputDesc', 'InputVar', 'ModelDesc', 'ModelFromMetaGraph']
-# TODO "variable" is not the right name to use for input here.
+# TODO "variable" is not a right name to use across this file.
-class InputVar(object):
+class InputDesc(object):
     """ Store metadata about input placeholders. """
     def __init__(self, type, shape, name, sparse=False):
         """
@@ -41,13 +41,16 @@ class InputVar(object):
         return pickle.loads(buf)
+InputVar = InputDesc
 @six.add_metaclass(ABCMeta)
 class ModelDesc(object):
     """ Base class for a model description """
-    def get_input_vars(self):
+    def get_reused_placehdrs(self):
         """
-        Create or return (if already created) raw input TF placeholder vars in the graph.
+        Create or return (if already created) raw input TF placeholders in the graph.
         Returns:
             list[tf.Tensor]: the list of input placeholders in the graph.
@@ -58,20 +61,21 @@ class ModelDesc(object):
         self.reuse_input_vars = ret
         return ret
-    # alias
-    get_reuse_placehdrs = get_input_vars
+    def get_input_vars(self):
+        logger.warn("[Deprecated] get_input_vars() was renamed to get_reused_placehdrs()!")
+        return self.get_reused_placehdrs()
     def build_placeholders(self, prefix=''):
         """
-        For each InputVar, create new placeholders with optional prefix and
+        For each input, create new placeholders with optional prefix and
         return them. Useful when building new towers.
         Returns:
             list[tf.Tensor]: the list of built placeholders.
         """
-        input_vars = self._get_input_vars()
+        input_vars = self._get_inputs()
         for v in input_vars:
-            tf.add_to_collection(INPUT_VARS_KEY, v.dumps())
+            tf.add_to_collection(INPUTS_KEY, v.dumps())
         ret = []
         for v in input_vars:
             placehdr_f = tf.placeholder if not v.sparse else tf.sparse_placeholder
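The first change in this hunk replaces a silent alias with a deprecation wrapper. A minimal sketch of the difference between the two patterns, using the standard warnings module rather than tensorpack's logger (class and method names here are hypothetical):

    import warnings

    class SilentAlias:
        def new_name(self):
            return "result"
        old_name = new_name        # old spelling keeps working, with no migration hint

    class WarningWrapper:
        def new_name(self):
            return "result"

        def old_name(self):        # old spelling keeps working and nags the caller
            warnings.warn("old_name() was renamed to new_name()", DeprecationWarning)
            return self.new_name()

The wrapper costs one extra call but gives users a signal before the old name is removed.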
@@ -80,20 +84,21 @@ class ModelDesc(object):
                 name=prefix + v.name))
         return ret
-    def get_input_vars_desc(self):
+    def get_inputs_desc(self):
         """
         Returns:
-            list[:class:`InputVar`]: list of the underlying :class:`InputVar`.
+            list[:class:`InputDesc`]: list of the underlying :class:`InputDesc`.
         """
-        return self._get_input_vars()
+        return self._get_inputs()
-    def _get_input_vars(self):  # keep backward compatibility
+    def _get_inputs(self):  # this is a better name than _get_input_vars
         """
-        :returns: a list of InputVar
+        :returns: a list of InputDesc
         """
-        return self._get_inputs()
+        logger.warn("[Deprecated] _get_input_vars() is renamed to _get_inputs()")
+        return self._get_input_vars()
-    def _get_inputs(self):  # this is a better name than _get_input_vars
+    def _get_input_vars(self):  # keep backward compatibility
         raise NotImplementedError()
     def build_graph(self, model_inputs):
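Net effect for user code: subclasses should now override _get_inputs() and return InputDesc objects, while models that still override _get_input_vars() keep working through the shim above, at the price of a warning. A sketch of a migrated model, using the InputDesc(type, shape, name) signature shown earlier (the model itself is hypothetical):

    import tensorflow as tf

    class MyModel(ModelDesc):
        def _get_inputs(self):       # was: def _get_input_vars(self)
            return [InputDesc(tf.float32, [None, 28, 28], 'input'),
                    InputDesc(tf.int32, [None], 'label')]

        def _build_graph(self, inputs):
            image, label = inputs    # tensors arrive in the order declared above
            pass                     # ... actual model omitted ...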
@@ -102,7 +107,7 @@ class ModelDesc(object):
         Args:
             model_inputs (list[tf.Tensor]): a list of inputs, corresponding to
-                InputVars of this model.
+                InputDesc of this model.
         """
         self._build_graph(model_inputs)
@@ -169,14 +174,14 @@ class ModelFromMetaGraph(ModelDesc):
         """
         tf.train.import_meta_graph(filename)
         all_coll = tf.get_default_graph().get_all_collection_keys()
-        for k in [INPUT_VARS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES,
+        for k in [INPUTS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES,
                   tf.GraphKeys.GLOBAL_VARIABLES]:
             assert k in all_coll, \
                 "Collection {} not found in metagraph!".format(k)
     def _get_inputs(self):
-        col = tf.get_collection(INPUT_VARS_KEY)
-        col = [InputVar.loads(v) for v in col]
+        col = tf.get_collection(INPUTS_KEY)
+        col = [InputDesc.loads(v) for v in col]
         return col
     def _build_graph(self, _, __):
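For context, dumps()/loads() are pickle-based (see the pickle.loads(buf) line earlier), so the input metadata travels inside the exported metagraph: build_placeholders() stores each InputDesc under INPUTS_KEY, and ModelFromMetaGraph reads it back after import_meta_graph(). A sketch of that round trip (the descriptor values are hypothetical):

    # Writing side, as in build_placeholders():
    desc = InputDesc(tf.float32, [None, 224, 224, 3], 'image')
    tf.add_to_collection(INPUTS_KEY, desc.dumps())    # pickled bytes stored with the graph

    # Reading side, as in ModelFromMetaGraph._get_inputs():
    restored = [InputDesc.loads(buf) for buf in tf.get_collection(INPUTS_KEY)]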
@@ -123,7 +123,7 @@ class OfflinePredictor(OnlinePredictor):
         """
         self.graph = tf.Graph()
         with self.graph.as_default():
-            input_placehdrs = config.model.get_input_vars()
+            input_placehdrs = config.model.get_reused_placehdrs()
             with TowerContext('', False):
                 config.model.build_graph(input_placehdrs)
@@ -47,7 +47,7 @@ class PredictConfig(object):
         self.input_names = input_names
         if self.input_names is None:
             # neither option is set, assume all inputs
-            raw_vars = self.model.get_input_vars_desc()
+            raw_vars = self.model.get_inputs_desc()
             self.input_names = [k.name for k in raw_vars]
         self.output_names = output_names
         assert_type(self.output_names, list)
@@ -28,7 +28,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
         with self.graph.as_default():
             # TODO backup summary keys?
             def fn(_):
-                config.model.build_graph(config.model.get_input_vars())
+                config.model.build_graph(config.model.get_reused_placehdrs())
             build_prediction_graph(fn, towers)
             self.sess = tf.Session(config=config.session_config)
@@ -39,14 +39,14 @@ class FeedInput(InputData):
         return self.ds.size()
     def _setup(self, trainer):
-        self.input_vars = trainer.model.get_input_vars()
+        self.input_placehdrs = trainer.model.get_reused_placehdrs()
         rds = RepeatedData(self.ds, -1)
         rds.reset_state()
         self.data_producer = rds.get_data()
     def next_feed(self):
         data = next(self.data_producer)
-        feed = dict(zip(self.input_vars, data))
+        feed = dict(zip(self.input_placehdrs, data))
         self._last_feed = feed
         return feed
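The rename makes clearer what next_feed() pairs up: placeholders on one side, one DataFlow datapoint on the other, in matching order. Schematically (variable names are hypothetical):

    placehdrs = trainer.model.get_reused_placehdrs()   # e.g. [image_ph, label_ph]
    datapoint = next(data_producer)                    # e.g. [image_array, label_array]
    feed = dict(zip(placehdrs, datapoint))             # {image_ph: image_array, label_ph: label_array}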
@@ -134,7 +134,7 @@ class QueueInput(FeedfreeInput):
         return self.ds.size()
     def _setup(self, trainer):
-        self.input_placehdrs = trainer.model.get_input_vars()
+        self.input_placehdrs = trainer.model.get_reused_placehdrs()
         assert len(self.input_placehdrs) > 0, \
             "QueueInput can only be used with input placeholders!"
         if self.queue is None:
@@ -182,7 +182,7 @@ class BatchQueueInput(FeedfreeInput):
         return self.ds.size() // self.batch_size
     def _setup(self, trainer):
-        self.input_placehdrs = trainer.model.get_input_vars()
+        self.input_placehdrs = trainer.model.get_reused_placehdrs()
         assert len(self.input_placehdrs) > 0, \
             "QueueInput can only be used with input placeholders!"
@@ -194,7 +194,7 @@ class BatchQueueInput(FeedfreeInput):
                     name=get_op_tensor_name(p.name)[0] + '-nobatch'))
         # dequeue_many requires fully-defined shapes
-        shape_err = "Use of BatchQueueInput requires input variables to have fully-defined "
+        shape_err = "Use of BatchQueueInput requires inputs to have fully-defined "
         "shapes except for the batch dimension"
         shapes = []
         for p in placehdrs_nobatch:
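The error message exists because dequeue_many can only assemble batches when every non-batch dimension is statically known. The loop that follows presumably performs a check along these lines (a sketch, not the verbatim continuation):

    shapes = []
    for p in placehdrs_nobatch:
        shp = p.get_shape()
        # any dimension still unknown after stripping the batch axis is fatal
        assert shp.is_fully_defined(), shape_err
        shapes.append(shp)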
@@ -226,7 +226,7 @@ class BatchQueueInput(FeedfreeInput):
 class DummyConstantInput(FeedfreeInput):
-    """ Input some constant variables. Only for debugging performance issues """
+    """ Input some constant tensor. Only for debugging performance issues """
     def __init__(self, shapes):
         self.shapes = shapes
@@ -238,9 +238,10 @@ class DummyConstantInput(FeedfreeInput):
         ret = []
         for idx, p in enumerate(placehdrs):
             with tf.device('/gpu:0'):
-                ret.append(tf.get_variable('dummy-' + p.op.name,
-                           shape=self.shapes[idx], dtype=p.dtype, trainable=False,
-                           initializer=tf.constant_initializer()))
+                ret.append(tf.get_variable(
+                    'dummy-' + p.op.name, shape=self.shapes[idx],
+                    dtype=p.dtype, trainable=False,
+                    initializer=tf.constant_initializer()))
         return ret
@@ -47,7 +47,7 @@ class PredictorFactory(object):
                 freeze_collection(SUMMARY_BACKUP_KEYS), \
                 tf.variable_scope(tf.get_variable_scope(), reuse=True):
             def fn(_):
-                self.model.build_graph(self.model.get_input_vars())
+                self.model.build_graph(self.model.get_reused_placehdrs())
             build_prediction_graph(fn, self.towers)
         self.tower_built = True
@@ -79,7 +79,7 @@ class SimpleTrainer(Trainer):
     def _setup(self):
         self._input_method._setup(self)
         model = self.model
-        self.input_vars = model.get_input_vars()
+        self.input_vars = model.get_reused_placehdrs()
         with TowerContext('', is_training=True):
             model.build_graph(self.input_vars)
             cost_var = model.get_cost()
@@ -19,8 +19,8 @@ PREDICT_TOWER = 'towerp'
 # extra variables to summarize during training in a moving-average way
 MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
-# placeholders for input variables
-INPUT_VARS_KEY = 'INPUT_VARIABLES'
+# metainfo for input tensors
+INPUTS_KEY = 'INPUTS_METAINFO'
 SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]