Commit b8a50d72 authored by Yuxin Wu

InputDesc -> tf.TensorSpec everywhere

parent ba679ab1
@@ -8,6 +8,9 @@ so you don't need to look at here very often.
 Here are a list of things that were changed, starting from an early version.
 TensorFlow itself also changes API and those are not listed here.
++ [2019/03/20] The concept of `InputDesc` was replaced by its equivalent in TF:
+  `tf.TensorSpec`. This may be a breaking change if you have customized
+  code that relies on internals of `InputDesc`.
 + [2018/08/27] msgpack is used again for "serialization to disk", because pyarrow
   has no compatibility between versions. To use pyarrow instead, `export TENSORPACK_COMPATIBLE_SERIALIZE=pyarrow`.
 + [2018/04/05] msgpack is replaced by pyarrow in favor of its speed. If you want old behavior,
......
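For readers migrating their own code, a minimal sketch of the rename. Note the argument order: `InputDesc` took `(dtype, shape, name)` while `tf.TensorSpec` takes `(shape, dtype, name)`, which is why every call site updated in this commit swaps the first two arguments:

```python
import tensorflow as tf

# Before this commit (dtype first):
#     InputDesc(tf.float32, (None, 28, 28), 'input')
# After this commit (shape first):
spec = tf.TensorSpec((None, 28, 28), tf.float32, 'input')
```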
@@ -375,6 +375,8 @@ _DEPRECATED_NAMES = set([
     'PrefetchOnGPUs',
     'DistributedTrainerReplicated',
     'DistributedTrainerParameterServer',
+    'InputDesc',
+    'inputs_desc',

     # renamed items that should not appear in docs
     'DumpTensor',
......
@@ -48,7 +48,7 @@ Most neural network training tasks are single-cost optimization.
 Tensorpack provides some trainer implementations for such tasks.
 These trainers will take care of step 1 (define the graph), with the following arguments:

-1. Some `InputDesc`, the metadata about the input.
+1. Some `tf.TensorSpec`, the signature of the input.
 2. An `InputSource`, where the input come from. See [Input Pipeline](input-source.html).
 3. A function which takes input tensors and returns the cost.
 4. A function which returns an optimizer.
......
@@ -11,7 +11,7 @@ This interface is enough for most types of single-cost tasks.
 A lot of examples are written in this interface.

 [SingleCost trainers](../modules/train.html#tensorpack.train.SingleCostTrainer)
-expects 4 arguments to setup the graph: `InputDesc`, `InputSource`, get_cost function, and an optimizer.
+expects 4 arguments to setup the graph: input signatures, `InputSource`, get_cost function, and an optimizer.
 `ModelDesc` describes a model by packing 3 of them together into one object:

 ```python
@@ -62,7 +62,7 @@ The function `launch_train_with_config(config, trainer)`
 uses the raw trainer interface under the hood, and is almost equivalent to the following two lines of code:

 ```python
 trainer.setup_graph(
-    my_model.get_inputs_desc(),
+    my_model.get_input_signature(),
     my_input_source,  # or QueueInput(my_dataflow)
     my_model.build_graph,
     my_model.get_optimizer)
......
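Putting the pieces together, a minimal `ModelDesc` under the new API might look like the sketch below (the model itself is hypothetical; note that `inputs()` can now return `tf.TensorSpec`s directly):

```python
import tensorflow as tf
from tensorpack import ModelDesc

class MyModel(ModelDesc):
    def inputs(self):
        # Signatures of the inputs; placeholders are built from them internally.
        return [tf.TensorSpec((None, 28, 28), tf.float32, 'image'),
                tf.TensorSpec((None,), tf.int32, 'label')]

    def build_graph(self, image, label):
        logits = tf.layers.dense(tf.layers.flatten(image), 10)
        cost = tf.losses.sparse_softmax_cross_entropy(label, logits)
        return tf.identity(cost, name='cost')

    def optimizer(self):
        return tf.train.AdamOptimizer(1e-3)
```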
@@ -42,7 +42,7 @@ def tower_func(image):

 def run_test(path, input):
     param_dict = dict(np.load(path))
     predictor = OfflinePredictor(PredictConfig(
-        inputs_desc=[InputDesc(tf.float32, (None, 227, 227, 3), 'input')],
+        input_signature=[tf.TensorSpec((None, 227, 227, 3), tf.float32, 'input')],
         tower_func=tower_func,
         session_init=DictRestore(param_dict),
         input_names=['input'],
......
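Once built, such a predictor is simply called with positional inputs matching `input_names`; a hedged usage sketch (the array contents are placeholders):

```python
import numpy as np

im = np.zeros((1, 227, 227, 3), dtype='float32')
outputs = predictor(im)   # a list of np.ndarray, one per name in output_names
```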
@@ -97,7 +97,7 @@ def CPM(image):

 def run_test(model_path, img_file):
     param_dict = dict(np.load(model_path))
     predict_func = OfflinePredictor(PredictConfig(
-        inputs_desc=[InputDesc(tf.float32, (None, 368, 368, 3), 'input')],
+        input_signature=[tf.TensorSpec((None, 368, 368, 3), tf.float32, 'input')],
         tower_func=CPM,
         session_init=DictRestore(param_dict),
         input_names=['input'],
......
@@ -59,7 +59,7 @@ def run_test(path, input):
     param_dict = {k.replace('/W', '/kernel').replace('/b', '/bias'): v for k, v in six.iteritems(param_dict)}
     predict_func = OfflinePredictor(PredictConfig(
-        inputs_desc=[InputDesc(tf.float32, (None, 224, 224, 3), 'input')],
+        input_signature=[tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')],
         tower_func=tower_func,
         session_init=DictRestore(param_dict),
         input_names=['input'],
......
@@ -62,7 +62,7 @@ def run_test(path, input):
     param_dict = {k.replace('/W', '/kernel').replace('/b', '/bias'): v for k, v in six.iteritems(param_dict)}
     predict_func = OfflinePredictor(PredictConfig(
-        inputs_desc=[InputDesc(tf.float32, (None, 224, 224, 3), 'input')],
+        input_signature=[tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')],
         tower_func=tower_func,
         session_init=DictRestore(param_dict),
         input_names=['input'],
......
@@ -88,7 +88,7 @@ class GANTrainer(TowerTrainer):
             input = StagingInput(input)

         # Setup input
-        cbs = input.setup(model.get_inputs_desc())
+        cbs = input.setup(model.get_input_signature())
         self.register_callback(cbs)

         if num_gpu <= 1:
@@ -105,7 +105,7 @@ class GANTrainer(TowerTrainer):
         not needed. Just calling model.build_graph directly is OK.
         """
         # Build the graph
-        self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
+        self.tower_func = TowerFuncWrapper(model.build_graph, model.get_input_signature())
         with TowerContext('', is_training=True):
             self.tower_func(*input.get_input_tensors())
         opt = model.get_optimizer()
@@ -127,7 +127,7 @@ class GANTrainer(TowerTrainer):
             model.build_graph(*inputs)
             return [model.d_loss, model.g_loss]

-        self.tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
+        self.tower_func = TowerFuncWrapper(get_cost, model.get_input_signature())
         devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
         cost_list = DataParallelBuilder.build_on_towers(
             list(range(num_gpu)),
@@ -163,11 +163,11 @@ class SeparateGANTrainer(TowerTrainer):
         assert min(d_period, g_period) == 1

         # Setup input
-        cbs = input.setup(model.get_inputs_desc())
+        cbs = input.setup(model.get_input_signature())
         self.register_callback(cbs)

         # Build the graph
-        self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
+        self.tower_func = TowerFuncWrapper(model.build_graph, model.get_input_signature())
         with TowerContext('', is_training=True), \
                 argscope(BatchNorm, internal_update=True):
             # should not hook the updates to both train_op, it will hurt training speed.
......
@@ -254,14 +254,11 @@ if __name__ == '__main__':
         eval_on_ILSVRC12(model, get_model_loader(args.load), ds)
     elif args.flops:
         # manually build the graph with batch=1
-        input_desc = [
-            InputDesc(tf.float32, [1, 224, 224, 3], 'input'),
-            InputDesc(tf.int32, [1], 'label')
-        ]
-        input = PlaceholderInput()
-        input.setup(input_desc)
         with TowerContext('', is_training=False):
-            model.build_graph(*input.get_input_tensors())
+            model.build_graph(
+                tf.placeholder(tf.float32, [1, 224, 224, 3], 'input'),
+                tf.placeholder(tf.int32, [1], 'label')
+            )
         model_utils.describe_trainable_vars()
         tf.profiler.profile(
......
@@ -64,7 +64,7 @@ class Model(ModelDesc):
         cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')

-        correct = tf.cast(tf.nn.in_top_k(logits, label, 1), tf.float32, name='correct')
+        correct = tf.cast(tf.nn.in_top_k(predictions=logits, targets=label, k=1), tf.float32, name='correct')

         # monitor training error
         add_moving_summary(tf.reduce_mean(correct, name='accuracy'))
@@ -76,7 +76,7 @@ class Model(ModelDesc):
         return tf.add_n([cost, wd_cost], name='cost')

     def optimizer(self):
-        lr = tf.get_variable('learning_rate', initializer=1e-2, trainable=False)
+        lr = tf.Variable(1e-2, name='learning_rate', trainable=False)
         tf.summary.scalar('lr', lr)
         return tf.train.AdamOptimizer(lr, epsilon=1e-3)
......
@@ -9,7 +9,7 @@ import os
 import tensorflow as tf
 from tensorflow.python.keras.layers import *

-from tensorpack import InputDesc, SyncMultiGPUTrainerReplicated
+from tensorpack import SyncMultiGPUTrainerReplicated
 from tensorpack.callbacks import *
 from tensorpack.contrib.keras import KerasModel
 from tensorpack.dataflow import FakeData, MapDataComponent
@@ -166,8 +166,8 @@ if __name__ == '__main__':
     M = KerasModel(
         resnet50,
-        inputs_desc=[InputDesc(tf.uint8, [None, 224, 224, 3], 'images')],
-        targets_desc=[InputDesc(tf.float32, [None, 1000], 'labels')],
+        input_signature=[tf.TensorSpec([None, 224, 224, 3], tf.uint8, 'images')],
+        target_signature=[tf.TensorSpec([None, 1000], tf.float32, 'labels')],
         input=df_train,
         trainer=SyncMultiGPUTrainerReplicated(num_gpu))
......
@@ -7,7 +7,7 @@ import numpy as np
 import tensorflow as tf
 from tensorflow import keras

-from tensorpack import InputDesc, QueueInput
+from tensorpack import QueueInput
 from tensorpack.callbacks import ModelSaver
 from tensorpack.contrib.keras import KerasModel
 from tensorpack.dataflow import BatchData, MapData, dataset
@@ -57,8 +57,8 @@ if __name__ == '__main__':
     M = KerasModel(
         model_func,
-        inputs_desc=[InputDesc(tf.float32, [None, IMAGE_SIZE, IMAGE_SIZE, 1], 'images')],
-        targets_desc=[InputDesc(tf.float32, [None, 10], 'labels')],
+        input_signature=[tf.TensorSpec([None, IMAGE_SIZE, IMAGE_SIZE, 1], tf.float32, 'images')],
+        target_signature=[tf.TensorSpec([None, 10], tf.float32, 'labels')],
         input=QueueInput(dataset_train))
     M.compile(
         optimizer=tf.train.AdamOptimizer(1e-3),
......
@@ -141,7 +141,7 @@ class InferenceRunner(InferenceRunnerBase):
         if self._tower_func is None:
             assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!"
             self._tower_func = self.trainer.tower_func
-        input_callbacks = self._input_source.setup(self._tower_func.inputs_desc)
+        input_callbacks = self._input_source.setup(self._tower_func.input_signature)

         vs_name = self.trainer._vs_name_for_predictor(self._device_id)
         logger.info("[InferenceRunner] Building tower '{}' on device {} {}...".format(
@@ -223,7 +223,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
             assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!"
             self._tower_func = self.trainer.tower_func
-        input_callbacks = self._input_source.setup(self._tower_func.inputs_desc)
+        input_callbacks = self._input_source.setup(self._tower_func.input_signature)

         with tf.variable_scope(tf.get_variable_scope(), reuse=True):
             for idx, dev in enumerate(self._devices):
                 vs_name = self.trainer._vs_name_for_predictor(idx)
......
@@ -8,8 +8,8 @@ import numpy as np
 from abc import ABCMeta, abstractmethod
 from collections import deque
 import six
-import tensorflow as tf

+from ..compat import tfv1
 from ..tfutils.common import get_op_tensor_name
 from ..utils import logger
 from .base import Callback
@@ -67,7 +67,7 @@ class GraphVarParam(HyperParam):
     def setup_graph(self):
         """ Will setup the assign operator for that variable. """
-        all_vars = tf.global_variables() + tf.local_variables()
+        all_vars = tfv1.global_variables() + tfv1.local_variables()
         for v in all_vars:
             if v.name == self.var_name:
                 self.var = v
......
@@ -141,7 +141,7 @@ class KerasPhaseCallback(Callback):

 def setup_keras_trainer(
         trainer, get_model,
-        inputs_desc, targets_desc,
+        input_signature, target_signature,
         input, optimizer, loss, metrics):
     """
     Args:
@@ -159,7 +159,7 @@ def setup_keras_trainer(
     assert isinstance(metrics, list), metrics
     model_caller = KerasModelCaller(get_model)

-    nr_inputs = len(inputs_desc)
+    nr_inputs = len(input_signature)

     def get_cost(*inputs):
         ctx = get_current_tower_context()
@@ -211,7 +211,7 @@ def setup_keras_trainer(
         return total_loss

     trainer.setup_graph(
-        inputs_desc + targets_desc,
+        input_signature + target_signature,
         input,
         get_cost,
         lambda: optimizer)
@@ -221,23 +221,27 @@ def setup_keras_trainer(

 class KerasModel(object):
-    def __init__(self, get_model, inputs_desc, targets_desc,
-                 input, trainer=None):
+    def __init__(self, get_model, input_signature=None, target_signature=None,
+                 input=None, trainer=None, inputs_desc=None, targets_desc=None):
         """
         Args:
             get_model (input1, input2, ... -> keras.Model):
                 A function which takes tensors, builds and returns a Keras model.
                 It will be part of the tower function.
-            inputs_desc ([InputDesc]):
-            targets_desc ([InputDesc]):
-            input (InputSource | DataFlow):
-            trainer (Trainer): the default will check the number of available
-                GPUs and use them all.
+            input_signature ([tf.TensorSpec]): required. The signature for inputs.
+            target_signature ([tf.TensorSpec]): required. The signature for the targets tensors.
+            input (InputSource | DataFlow): the InputSource or DataFlow where the input data comes from.
+            trainer (Trainer): the default will check the number of available GPUs and use them all.
+            inputs_desc, targets_desc: deprecated names for `input_signature` and `target_signature`
         """
+        if inputs_desc is not None:
+            input_signature = inputs_desc
+        if targets_desc is not None:
+            target_signature = targets_desc
         self.get_model = get_model
         assert callable(get_model), get_model
-        self.inputs_desc = inputs_desc
-        self.targets_desc = targets_desc
+        self.input_signature = input_signature
+        self.target_signature = target_signature
         if trainer is None:
             nr_gpu = get_nr_gpu()
             if nr_gpu <= 1:
@@ -248,6 +252,7 @@ class KerasModel(object):
             assert isinstance(trainer, Trainer), trainer
             assert not isinstance(trainer, DistributedTrainerBase)

+        assert input is not None, "Argument 'input' is required!"
         self.input = apply_default_prefetch(input, trainer)
         self.trainer = trainer
@@ -267,7 +272,8 @@ class KerasModel(object):
         self._stats_to_inference = loss + metrics + [TOTAL_LOSS_NAME]
         setup_keras_trainer(
             self.trainer, get_model=self.get_model,
-            inputs_desc=self.inputs_desc, targets_desc=self.targets_desc,
+            input_signature=self.input_signature,
+            target_signature=self.target_signature,
             input=self.input,
             optimizer=optimizer,
             loss=loss,
......
@@ -18,12 +18,41 @@ TensorSpec = backport_tensor_spec()

 __all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']


+def build_or_reuse_placeholder(tensor_spec):
+    """
+    Build a tf.placeholder from the metadata in the given tensor spec, or return an existing one.
+
+    Args:
+        tensor_spec (tf.TensorSpec):
+
+    Returns:
+        tf.Tensor:
+    """
+    g = tfv1.get_default_graph()
+    name = tensor_spec.name
+    try:
+        tensor = g.get_tensor_by_name(name + ':0')
+        assert "Placeholder" in tensor.op.type, "Tensor {} exists but is not a placeholder!".format(name)
+        assert tensor_spec.is_compatible_with(tensor), \
+            "Tensor {} exists but is not compatible with the signature!".format(tensor)
+        return tensor
+    except KeyError:
+        with tfv1.name_scope(None):   # clear any name scope it might get called in
+            ret = tfv1.placeholder(
+                tensor_spec.dtype, shape=tensor_spec.shape, name=tensor_spec.name)
+        return ret
+
+
 class InputDesc(
         namedtuple('InputDescTuple', ['type', 'shape', 'name'])):
     """
-    Metadata about an input entry point to the graph.
-    This metadata can be later used to build placeholders or other types of
-    input source.
+    An equivalent of `tf.TensorSpec`.
+
+    History: this concept is used to represent metadata about the inputs,
+    which can be later used to build placeholders or other types of input source.
+    It is introduced much much earlier than the equivalent concept `tf.TensorSpec`
+    was introduced in TensorFlow.
+    Therefore, we now switched to use `tf.TensorSpec`, but keep this here for compatibility reasons.
    """

     def __new__(cls, type, shape, name):
@@ -33,64 +62,9 @@ class InputDesc(
             shape (tuple):
             name (str):
         """
-        shape = tuple(shape)    # has to be tuple for "self" to be hashable
+        # TODO mark deprecated
         assert isinstance(type, tf.DType), type
-        if any(k in name for k in [':', '/', ' ']):
-            raise ValueError("Invalid InputDesc name: '{}'".format(name))
-        self = super(InputDesc, cls).__new__(cls, type, shape, name)
-        self._cached_placeholder = {}
-        return self
-
-    def _build_placeholder(self):
-        """
-        Build a tf.placeholder from the metadata.
-
-        Returns:
-            tf.Tensor:
-        """
-        with tfv1.name_scope(None):   # clear any name scope it might get called in
-            ret = tfv1.placeholder(
-                self.type, shape=self.shape, name=self.name)
-        self._register_cached_placeholder(ret)
-        return ret
-
-    # cannot memoize here, because InputDesc is hashed by its fields.
-    def build_placeholder_reuse(self):
-        """
-        Build a tf.placeholder from the metadata, or return an old one.
-
-        Returns:
-            tf.Tensor:
-        """
-        g = tfv1.get_default_graph()
-        if g in self._cached_placeholder:
-            return self._cached_placeholder[g]
-        else:
-            return self._build_placeholder()
-
-    def _register_cached_placeholder(self, placeholder):
-        graph = placeholder.graph
-        assert graph not in self._cached_placeholder, \
-            "Placeholder for this InputDesc had been created before! This is a bug."
-        self._cached_placeholder[graph] = placeholder
-
-    @staticmethod
-    def _from_placeholder(placeholder):
-        name = placeholder.op.name
-        if name.endswith('_1') or name.endswith('_2'):
-            logger.error("Creating InputDesc from a placeholder named {}.".format(name))
-            logger.error("You might have mistakenly created this placeholder multiple times!")
-        ret = InputDesc(
-            placeholder.dtype,
-            tuple(placeholder.shape.as_list()),
-            name)
-        ret._register_cached_placeholder(placeholder)
-        return ret
-
-    @staticmethod
-    def _from_tensor_spec(spec):
-        assert spec.name is not None, "TensorSpec should have a name!"
-        return InputDesc(spec.dtype, tuple(spec.shape.as_list()), spec.name)
+        return tf.TensorSpec(shape=shape, dtype=type, name=name)


 class ModelDescBase(object):
@@ -100,29 +74,22 @@ class ModelDescBase(object):
     @memoized_method
     def get_inputs_desc(self):
+        # TODO mark deprecated
+        return self.get_input_signature()
+
+    @memoized_method
+    def get_input_signature(self):
         """
         Returns:
-            A list of :class:`InputDesc`, which describes the inputs of this model.
+            A list of :class:`tf.TensorSpec`, which describes the inputs of this model.
             The result is cached for each instance of :class:`ModelDescBase`.
         """
-        try:
-            ret = self._get_inputs()
-            log_deprecated(
-                "ModelDescBase._get_inputs() interface",
-                "Use inputs() instead!",
-                "2019-03-30")
-            return ret
-        except NotImplementedError:
-            with tf.Graph().as_default() as G:   # create these placeholder in a temporary graph
-                inputs = self.inputs()
-                if isinstance(inputs[0], tf.Tensor):
-                    for p in inputs:
-                        assert p.graph == G, "Placeholders returned by inputs() should be created inside inputs()!"
-                    return [InputDesc._from_placeholder(p) for p in inputs]
-                else:
-                    for p in inputs:
-                        assert isinstance(p, TensorSpec), type(p)
-                    return [InputDesc._from_tensor_spec(p) for p in inputs]
+        with tf.Graph().as_default() as G:   # create these placeholder in a temporary graph
+            inputs = self.inputs()
+            if isinstance(inputs[0], tf.Tensor):
+                for p in inputs:
+                    assert p.graph == G, "Placeholders returned by inputs() should be created inside inputs()!"
+            return [TensorSpec(shape=p.shape, dtype=p.dtype, name=p.name) for p in inputs]

     @property
     def input_names(self):
@@ -130,7 +97,7 @@ class ModelDescBase(object):
         Returns:
             [str]: the names of all the inputs.
         """
-        return [k.name for k in self.get_inputs_desc()]
+        return [k.name for k in self.get_input_signature()]

     def _get_inputs(self):
         raise NotImplementedError()
@@ -147,7 +114,7 @@ class ModelDescBase(object):
         Also, you should never call this method by yourself.

         Returns:
-            list[tf.placeholder] or list[tf.TensorSpec], to be converted to :class:`InputDesc`.
+            list[tf.TensorSpec or tf.placeholder]. To be converted to :class:`tf.TensorSpec`.
         """
         raise NotImplementedError()
@@ -166,9 +133,9 @@ class ModelDescBase(object):
         may require it to return necessary information to build the trainer.
         For example, `SingleCostTrainer` expect this method to return the cost tensor.
         """
-        assert len(args) == len(self.get_inputs_desc()), \
+        assert len(args) == len(self.get_input_signature()), \
             "Number of inputs passed to the graph != number of inputs defined " \
-            "in ModelDesc! ({} != {})".format(len(args), len(self.get_inputs_desc()))
+            "in ModelDesc! ({} != {})".format(len(args), len(self.get_input_signature()))
         log_deprecated(
             "ModelDescBase._build_graph() interface",
             "Use build_graph() instead!",
......
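To illustrate the new helper added above, a small sketch of `build_or_reuse_placeholder`: within one graph, asking twice for the same spec returns the same placeholder (assuming TF1-style graph mode):

```python
import tensorflow as tf
from tensorpack.graph_builder.model_desc import build_or_reuse_placeholder

spec = tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')
with tf.Graph().as_default():
    p1 = build_or_reuse_placeholder(spec)   # creates placeholder 'input'
    p2 = build_or_reuse_placeholder(spec)   # finds 'input:0' and reuses it
    assert p1 is p2
```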
@@ -19,6 +19,7 @@ from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
 from ..utils.concurrency import ShareSessionThread
 from .input_source_base import InputSource
+from ..graph_builder.model_desc import build_or_reuse_placeholder

 try:
     from tensorflow.python.ops.data_flow_ops import StagingArea
@@ -59,7 +60,7 @@ class PlaceholderInput(InputSource):
         pass

     def _setup(self, inputs):
-        self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
+        self._all_placehdrs = [build_or_reuse_placeholder(v) for v in inputs]

     def _get_input_tensors(self):
         return self._all_placehdrs
@@ -110,7 +111,7 @@ class FeedInput(InputSource):
     def _setup(self, inputs):
         # placeholders as input are always safe to reuse.
-        self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
+        self._all_placehdrs = [build_or_reuse_placeholder(v) for v in inputs]
         self._cb = self._FeedCallback(self._iter_ds, self._all_placehdrs)

     def _get_input_tensors(self):
@@ -196,7 +197,7 @@ class QueueInput(FeedfreeInput):
         Args:
             ds(DataFlow): the input DataFlow.
             queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
-                should match the corresponding InputDesc of the model.
+                should match the corresponding input signature of the model.
                 Defaults to a FIFO queue of size 50.
         """
         if not isinstance(ds, DataFlow):
@@ -210,12 +211,12 @@ class QueueInput(FeedfreeInput):
         return len(self.ds)

     def _setup(self, inputs):
-        self._input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
+        self._input_placehdrs = [build_or_reuse_placeholder(v) for v in inputs]
         assert len(self._input_placehdrs) > 0, \
             "QueueInput has to be used with some inputs!"
         with self.cached_name_scope():
             if self.queue is None:
-                self.queue = tf.FIFOQueue(
+                self.queue = tfv1.FIFOQueue(
                     50, [x.dtype for x in self._input_placehdrs],
                     name='input_queue')
             logger.info("Setting up the queue '{}' for CPU prefetching ...".format(self.queue.name))
@@ -287,7 +288,7 @@ class BatchQueueInput(QueueInput):
             ds(DataFlow): the input DataFlow.
             batch_size(int): the batch size.
             queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
-                should match the corresponding InputDesc of the model.
+                should match the corresponding input signature of the model.
                 Defaults to a FIFO queue of size 3000.
         """
         super(BatchQueueInput, self).__init__(ds, queue)
@@ -298,9 +299,9 @@ class BatchQueueInput(QueueInput):
     def _setup(self, inputs):
         logger.info("Setting up the queue for CPU prefetching ...")
-        self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
+        self.input_placehdrs = [build_or_reuse_placeholder(v) for v in inputs]
         assert len(self.input_placehdrs) > 0, \
-            "BatchQueueInput has to be used with some InputDesc!"
+            "BatchQueueInput has to be used with some input signature!"

         # prepare placeholders without the first dimension
         placehdrs_nobatch = []
@@ -364,8 +365,8 @@ class TensorInput(FeedfreeInput):
             assert size > 0
         self._fixed_size = size

-    def _setup(self, inputs_desc):
-        self._desc = inputs_desc
+    def _setup(self, input_signature):
+        self._spec = input_signature

     def _size(self):
         if self._fixed_size is None:
@@ -376,8 +377,8 @@ class TensorInput(FeedfreeInput):
         with self.cached_name_scope():
             ret = self.get_tensor_fn()
         assert isinstance(ret, (list, tuple)), "get_tensor_fn needs to return a list!"
-        assert len(ret) == len(self._desc), \
-            "get_tensor_fn returns {} tensors but there are {} inputs".format(len(ret), len(self._desc))
+        assert len(ret) == len(self._spec), \
+            "get_tensor_fn returns {} tensors but there are {} inputs".format(len(ret), len(self._spec))
         return ret
@@ -399,7 +400,7 @@ class DummyConstantInput(TensorInput):
             assert len(self.shapes) == len(self._desc)
             for idx, p in enumerate(self._desc):
                 tlist.append(tf.constant(
-                    0, dtype=p.type,
+                    0, dtype=p.dtype,
                     name='dummy-{}-{}'.format(p.name, ctx.index),
                     shape=self.shapes[idx]))
             return tlist
@@ -429,15 +430,14 @@ class ZMQInput(TensorInput):
             return ret
         super(ZMQInput, self).__init__(fn)

-    def _setup(self, inputs_desc):
-        assert len(inputs_desc) > 0, \
-            "ZMQInput has to be used with InputDesc!"
-        self._desc = inputs_desc
+    def _setup(self, input_signature):
+        assert len(input_signature) > 0, \
+            "ZMQInput has to be used with input signature!"

         import zmq_ops
         self._zmq_pull_socket = zmq_ops.ZMQPullSocket(
             self._end_point,
-            [x.type for x in inputs_desc],
+            [x.dtype for x in input_signature],
             hwm=self._hwm,
             bind=self._bind)
@@ -458,23 +458,23 @@ class TFDatasetInput(FeedfreeInput):
             raise ValueError("TFDatasetInput takes a tf.data.Dataset! Got {}".format(dataset))
         self._dataset = dataset

-    def _setup(self, inputs_desc):
-        self._desc = inputs_desc
+    def _setup(self, input_signature):
+        self._spec = input_signature
         types = self._dataset.output_types
-        desc_types = tuple([k.type for k in inputs_desc])
-        assert len(types) == len(desc_types), \
-            "Dataset and InputDesc has different length! {} != {}".format(
-                len(types), len(desc_types))
-        assert types == desc_types, \
-            "Types of dataset and InputDesc don't match! {} != {}".format(
-                str(types), str(desc_types))
+        spec_types = tuple([k.dtype for k in input_signature])
+        assert len(types) == len(spec_types), \
+            "Dataset and input signature have different length! {} != {}".format(
+                len(types), len(spec_types))
+        assert types == spec_types, \
+            "Data types of dataset and input signature don't match! {} != {}".format(
+                str(types), str(spec_types))
         shapes = self._dataset.output_shapes
-        desc_shapes = [k.shape for k in inputs_desc]
-        for idx, (s1, s2) in enumerate(zip(shapes, desc_shapes)):
+        spec_shapes = [k.shape for k in input_signature]
+        for idx, (s1, s2) in enumerate(zip(shapes, spec_shapes)):
             s2 = tf.TensorShape(s2)
             assert s2.is_compatible_with(s1), \
-                "InputDesc '{}' has incompatible shape with dataset! {} vs {}".format(
-                    inputs_desc[idx].name, s2, s1)
+                "Input signature '{}' has incompatible shape with dataset! {} vs {}".format(
+                    input_signature[idx].name, s2, s1)
         self._iterator = self._dataset.make_initializable_iterator()
         self._init_op = self._iterator.initializer
@@ -482,11 +482,11 @@ class TFDatasetInput(FeedfreeInput):
         self._init_op.run()

     def _get_input_tensors(self):
-        desc_shapes = [k.shape for k in self._desc]
+        spec_shapes = [k.shape for k in self._spec]
         ret = self._iterator.get_next()
-        assert len(ret) == len(desc_shapes), \
-            "Dataset returns {} tensors but there are {} inputs!".format(len(ret), len(desc_shapes))
-        for t, shp in zip(ret, desc_shapes):
+        assert len(ret) == len(spec_shapes), \
+            "Dataset returns {} tensors but there are {} inputs!".format(len(ret), len(spec_shapes))
+        for t, shp in zip(ret, spec_shapes):
             t.set_shape(shp)
         return ret
......
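As the checks in `TFDatasetInput._setup` above require, the dataset's element types and shapes must be compatible with the model's input signature; a hedged sketch of a matching pairing:

```python
import numpy as np
import tensorflow as tf
from tensorpack import TFDatasetInput

# A dataset whose elements are float32 tensors of shape (?, 28, 28),
# compatible with tf.TensorSpec((None, 28, 28), tf.float32, 'image'):
ds = tf.data.Dataset.from_tensor_slices(
    np.zeros((100, 28, 28), dtype='float32')).batch(32)
input_source = TFDatasetInput(ds)   # dtype/shape checks happen in _setup()
```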
@@ -12,6 +12,7 @@ from ..callbacks.base import CallbackFactory
 from ..tfutils.common import get_op_tensor_name
 from ..utils import logger
 from ..utils.argtools import call_only_once, memoized_method
+from ..graph_builder.model_desc import build_or_reuse_placeholder

 __all__ = ['InputSource', 'remap_input_source']
@@ -86,20 +87,20 @@ class InputSource(object):
         pass

     @call_only_once
-    def setup(self, inputs_desc):
+    def setup(self, input_signature):
         """
         Args:
-            inputs_desc (list[InputDesc]): list of input desc
+            input_signature (list[tf.TensorSpec]): list of specs for each input tensor

         Returns:
             list[Callback]: extra callbacks needed by this InputSource.
                 callbacks of InputSource cannot use any `trigger*()` method.
         """
-        self._setup(inputs_desc)
+        self._setup(input_signature)
         self._setup_done = True
         return self.get_callbacks()

-    def _setup(self, inputs_desc):
+    def _setup(self, input_signature):
         pass

     def setup_done(self):
@@ -190,8 +191,8 @@ class ProxyInputSource(InputSource):
     def _get_input_tensors(self):
         return self._input.get_input_tensors()

-    def _setup(self, inputs_desc):
-        self._input.setup(inputs_desc)
+    def _setup(self, input_signature):
+        self._input.setup(input_signature)

     def _get_callbacks(self):
         return self._input.get_callbacks()
@@ -226,11 +227,11 @@ def remap_input_source(input, names):
        input1 = QueueInput(ds)
        # assume ds produces 'image' and 'label', but the graph takes more
        # inputs for some reasons, or takes inputs of a different order:
-       inputs_desc = [InputDesc(tf.float32, (None,10), 'score'),
-                      InputDesc(tf.float32, (None,20,20,3), 'label'),
-                      InputDesc(tf.int32, (None,), 'image') ]
+       input_signature = [tf.TensorSpec((None,10), tf.float32, 'score'),
+                          tf.TensorSpec((None,20,20,3), tf.float32, 'label'),
+                          tf.TensorSpec((None,), tf.int32, 'image') ]
        input2 = remap_input_source(input1, ['image', 'label'])
-       input2.setup(inputs_desc)
+       input2.setup(input_signature)
        # now, input2.get_input_tensors() will return a placeholder for 'score',
        # plus the tensors returned by input1.get_input_tensors()
     """
@@ -240,7 +241,7 @@ def remap_input_source(input, names):
         self._names = tuple(names)

     def _setup(self, inputs):
-        self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
+        self._all_placehdrs = [build_or_reuse_placeholder(v) for v in inputs]
         inputs_subset = get_sublist_by_names(inputs, self._names)
         self._input.setup(inputs_subset)
......
@@ -155,7 +155,7 @@ class OfflinePredictor(OnlinePredictor):
         self.graph = config._maybe_create_graph()
         with self.graph.as_default():
             input = PlaceholderInput()
-            input.setup(config.inputs_desc)
+            input.setup(config.input_signature)
             with PredictTowerContext(''):
                 config.tower_func(*input.get_input_tensors())
......
@@ -18,7 +18,7 @@ class PredictConfig(object):
     def __init__(self,
                  model=None,
                  tower_func=None,
-                 inputs_desc=None,
+                 input_signature=None,

                  input_names=None,
                  output_names=None,
@@ -27,11 +27,18 @@ class PredictConfig(object):
                  session_init=None,
                  return_input=False,
                  create_graph=True,
+
+                 inputs_desc=None
                  ):
         """
-        You need to set either `model`, or `inputs_desc` plus `tower_func`.
-        They are needed to construct the graph.
-        You'll also have to set `output_names` as it does not have a default.
+        Users need to provide enough arguments to create a tower function,
+        which will be used to construct the graph.
+        This can be provided in the following ways:
+
+        1. `model`: a :class:`ModelDesc` instance. It will contain a tower function by itself.
+        2. `tower_func`: a :class:`tfutils.TowerFuncWrapper` instance.
+           Provide a tower function instance directly.
+        3. `tower_func`: a symbolic function and `input_signature`: the signature of the function.
+           Provide both a function and its signature.

         Example:
@@ -42,15 +49,14 @@ class PredictConfig(object):
                 output_names=['linear/output', 'prediction'])

         Args:
-            model (ModelDescBase): to be used to obtain inputs_desc and tower_func.
+            model (ModelDescBase): to be used to construct a tower function.
             tower_func: a callable which takes input tensors (by positional args) and construct a tower.
-                or a :class:`tfutils.TowerFuncWrapper` instance, which packs both `inputs_desc` and function together.
-            inputs_desc ([InputDesc]): if tower_func is a plain function (instead of a TowerFuncWrapper), this describes
-                the list of inputs it takes.
-            input_names (list): a list of input tensor names. Defaults to match inputs_desc.
-                The name can be either the name of a tensor, or the name of one input defined
-                by `inputs_desc` or by `model`.
+                or a :class:`tfutils.TowerFuncWrapper` instance.
+            input_signature ([tf.TensorSpec]): if tower_func is a plain function (instead of a TowerFuncWrapper),
+                this describes the list of inputs it takes.
+            input_names (list): a list of input tensor names. Defaults to match input_signature.
+                The name can be either the name of a tensor, or the name of one input of the tower.
             output_names (list): a list of names of the output tensors to predict, the
                 tensors can be any tensor in the graph that's computable from the tensors correponding to `input_names`.
@@ -62,23 +68,29 @@ class PredictConfig(object):
             return_input (bool): same as in :attr:`PredictorBase.return_input`.
             create_graph (bool): create a new graph, or use the default graph
                 when predictor is first initialized.
+            inputs_desc (list[tf.TensorSpec]): old (deprecated) name for `input_signature`.
         """
         def assert_type(v, tp, name):
             assert isinstance(v, tp), \
-                "{} has to be type '{}', but an object of type '{}' found.".format(
+                "Argument '{}' has to be type '{}', but an object of type '{}' found.".format(
                     name, tp.__name__, v.__class__.__name__)

+        if inputs_desc is not None:
+            # TODO warn deprecated or not?
+            assert input_signature is None, "Cannot set both inputs_desc and input_signature!"
+            input_signature = inputs_desc
+
         if model is not None:
             assert_type(model, ModelDescBase, 'model')
-            assert inputs_desc is None and tower_func is None
-            self.inputs_desc = model.get_inputs_desc()
-            self.tower_func = TowerFuncWrapper(model.build_graph, self.inputs_desc)
+            assert input_signature is None and tower_func is None
+            self.input_signature = model.get_input_signature()
+            self.tower_func = TowerFuncWrapper(model.build_graph, self.input_signature)
         else:
             if isinstance(tower_func, TowerFuncWrapper):
-                inputs_desc = tower_func.inputs_desc
-            assert inputs_desc is not None and tower_func is not None
-            self.inputs_desc = inputs_desc
-            self.tower_func = TowerFuncWrapper(tower_func, inputs_desc)
+                input_signature = tower_func.input_signature
+            assert input_signature is not None and tower_func is not None
+            self.input_signature = input_signature
+            self.tower_func = TowerFuncWrapper(tower_func, input_signature)

         if session_init is None:
             session_init = JustCurrentSession()
@@ -93,20 +105,22 @@ class PredictConfig(object):
         # inputs & outputs
         self.input_names = input_names
         if self.input_names is None:
-            self.input_names = [k.name for k in self.inputs_desc]
+            self.input_names = [k.name for k in self.input_signature]
+        assert output_names is not None, "Argument 'output_names' is not provided!"
         self.output_names = output_names
         assert_type(self.output_names, list, 'output_names')
         assert_type(self.input_names, list, 'input_names')
         if len(self.input_names) == 0:
             logger.warn('PredictConfig receives empty "input_names".')
-        # assert len(self.input_names), self.input_names
         for v in self.input_names:
             assert_type(v, six.string_types, 'Each item in input_names')
-        assert len(self.output_names), self.output_names
+        assert len(self.output_names), "Argument 'output_names' cannot be empty!"

         self.return_input = bool(return_input)
         self.create_graph = bool(create_graph)

+        self.inputs_desc = input_signature  # TODO a little bit of compatibility

     def _maybe_create_graph(self):
         if self.create_graph:
             return tf.Graph()
......
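A sketch of the third way listed in the new docstring — a plain symbolic function plus its `input_signature` (the tower function here is hypothetical):

```python
import tensorflow as tf
from tensorpack import PredictConfig

def my_tower(image):
    logits = tf.layers.dense(tf.layers.flatten(image), 10, name='linear')
    tf.nn.softmax(logits, name='prediction')

config = PredictConfig(
    tower_func=my_tower,
    input_signature=[tf.TensorSpec((None, 28, 28, 1), tf.float32, 'image')],
    input_names=['image'],
    output_names=['prediction'])
```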
@@ -21,7 +21,7 @@ class FeedfreePredictor(PredictorBase):
         Args:
             config (PredictConfig): the config to use.
             input_source (InputSource): the feedfree InputSource to use.
-                Must match the inputs_desc in config.
+                Must match the signature of the tower function in config.
         """
         self._config = config
         self._input_source = input_source
@@ -33,7 +33,7 @@ class FeedfreePredictor(PredictorBase):
         self.graph = config._maybe_create_graph()
         with self.graph.as_default():
             self._input_callbacks = Callbacks(
-                self._input_source.setup(config.inputs_desc))
+                self._input_source.setup(config.input_signature))
             with PredictTowerContext(''):
                 self._input_tensors = self._input_source.get_input_tensors()
                 config.tower_func(*self._input_tensors)
......
@@ -4,7 +4,6 @@

 import tensorflow as tf

-from ..graph_builder.model_desc import InputDesc
 from ..input_source import PlaceholderInput
 from ..tfutils.tower import PredictTowerContext
 from ..utils import logger
@@ -33,7 +32,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
             handles = []

             input = PlaceholderInput()
-            input.setup(config.inputs_desc)
+            input.setup(config.input_signature)

             for idx, t in enumerate(towers):
                 tower_name = 'tower' + str(t)
@@ -102,10 +101,10 @@ class DataParallelOfflinePredictor(OnlinePredictor):
             for idx, t in enumerate(towers):
                 tower_name = 'tower' + str(t)
-                inputs_desc = [InputDesc(desc.type, desc.shape, tower_name + '_' + desc.name)
-                               for desc in config.inputs_desc]
+                new_sig = [tf.TensorSpec(dtype=p.dtype, shape=p.shape, name=tower_name + '_' + p.name)
+                           for p in config.input_signature]
                 input = PlaceholderInput()
-                input.setup(inputs_desc)
+                input.setup(new_sig)
                 with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0), \
                         tf.device('/gpu:{}'.format(t)), \
......
@@ -28,7 +28,7 @@ class ModelExporter(object):
         Args:
             config (PredictConfig): the config to use.
-                The graph will be built with `config.tower_func` and `config.inputs_desc`.
+                The graph will be built with the tower function defined by this `PredictConfig`.
                 Then the input / output names will be used to export models for inference.
         """
         super(ModelExporter, self).__init__()
@@ -51,7 +51,7 @@ class ModelExporter(object):
         self.graph = self.config._maybe_create_graph()
         with self.graph.as_default():
             input = PlaceholderInput()
-            input.setup(self.config.inputs_desc)
+            input.setup(self.config.input_signature)
             with PredictTowerContext(''):
                 self.config.tower_func(*input.get_input_tensors())
@@ -116,7 +116,7 @@ class ModelExporter(object):
         self.graph = self.config._maybe_create_graph()
         with self.graph.as_default():
             input = PlaceholderInput()
-            input.setup(self.config.inputs_desc)
+            input.setup(self.config.input_signature)
             with PredictTowerContext(''):
                 self.config.tower_func(*input.get_input_tensors())
......
...@@ -257,24 +257,27 @@ class TowerFuncWrapper(object): ...@@ -257,24 +257,27 @@ class TowerFuncWrapper(object):
Conceptually, this class is roughly equivalent to `tf.function` with input signature, introduced in TF 2.0. Conceptually, this class is roughly equivalent to `tf.function` with input signature, introduced in TF 2.0.
""" """
def __init__(self, tower_fn, inputs_desc): def __init__(self, tower_fn, input_signature):
""" """
Args: Args:
tower_func: a function which builds one tower in the graph. tower_func: a function which builds one tower in the graph.
It takes several input tensors and could return anything. It takes several input tensors and could return anything.
inputs_desc ([InputDesc]): list of :class:`InputDesc`. input_signature ([TensorSpec]): list of :class:`tf.TensorSpec`.
They are used to figure out the names for the input tensors. They are used to figure out the names for the input tensors.
""" """
assert callable(tower_fn), tower_fn assert callable(tower_fn), tower_fn
self._inputs_desc_names = [k.name for k in inputs_desc] self._inputs_names = [k.name for k in input_signature]
assert len(set(self._inputs_desc_names)) == len(self._inputs_desc_names), \ assert len(set(self._inputs_names)) == len(self._inputs_names), \
"Duplicated names in inputs_desc! " + str(self._inputs_desc_names) "Duplicated names in input_signature! " + str(self._inputs_names)
for name in self._inputs_names:
if any(k in name for k in [':', '/', ' ']):
raise ValueError("Invalid input name: '{}'".format(name))
self._tower_fn = tower_fn self._tower_fn = tower_fn
self._inputs_desc = inputs_desc self._input_signature = input_signature
self._handles = [] self._handles = []
def __new__(cls, tower_fn, inputs_desc): def __new__(cls, tower_fn, _):
# to avoid double-wrapping a function # to avoid double-wrapping a function
if isinstance(tower_fn, TowerFuncWrapper): if isinstance(tower_fn, TowerFuncWrapper):
return tower_fn return tower_fn
...@@ -285,7 +288,7 @@ class TowerFuncWrapper(object): ...@@ -285,7 +288,7 @@ class TowerFuncWrapper(object):
ctx = get_current_tower_context() ctx = get_current_tower_context()
assert ctx is not None, "Function must be called under TowerContext!" assert ctx is not None, "Function must be called under TowerContext!"
output = self._tower_fn(*args) output = self._tower_fn(*args)
handle = TowerTensorHandle(ctx, args, output, self._inputs_desc) handle = TowerTensorHandle(ctx, args, output, self._input_signature)
self._handles.append(handle) self._handles.append(handle)
return output return output
...@@ -298,9 +301,14 @@ class TowerFuncWrapper(object): ...@@ -298,9 +301,14 @@ class TowerFuncWrapper(object):
""" """
return TowerTensorHandles(self._handles) return TowerTensorHandles(self._handles)
@property
def input_signature(self):
return self._input_signature
@property @property
def inputs_desc(self): def inputs_desc(self):
return self._inputs_desc # TODO mark deprecated
return self._input_signature
class TowerTensorHandles(object): class TowerTensorHandles(object):
@@ -354,14 +362,14 @@ class TowerTensorHandle(object):
     """

     @HIDE_DOC
-    def __init__(self, ctx, input, output, inputs_desc=None):
+    def __init__(self, ctx, input, output, input_signature=None):
         self._ctx = ctx
         self._extra_tensor_names = {}
-        if inputs_desc is not None:
-            assert len(inputs_desc) == len(input)
-            self._extra_tensor_names = {
-                get_op_tensor_name(x.name)[1]: y for x, y in zip(inputs_desc, input)}
+        if input_signature is not None:
+            assert len(input_signature) == len(input)
+            self._extra_tensor_names = {
+                get_op_tensor_name(x.name)[1]: y for x, y in zip(input_signature, input)}
         self._input = input
         self._output = output
@@ -379,7 +387,7 @@ class TowerTensorHandle(object):
         1. The name of the tensor without any tower prefix.

-        2. The name of an :class:`InputDesc`, if it is used when building the tower.
+        2. A name in the input signature, if it is used when building the tower.

            In the second case, this method will return the tensor that's used as the corresponding
            input to the tower. Note that this tensor may have a different name (e.g. may be an output of a queue).
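The practical effect: an input tensor can be looked up on a tower handle by its signature name. A hedged sketch, reusing the hypothetical `tower_func` from the previous snippet after it has been called under a `TowerContext`:

```python
handle = tower_func.towers[0]        # TowerTensorHandle for the first tower
image = handle.get_tensor('image')   # resolved through the input signature;
                                     # may be e.g. a dequeue op, not a placeholder
```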
...
@@ -87,7 +87,7 @@ def launch_train_with_config(config, trainer):
     # We should gradually stay away from this unuseful abstraction.
     # TowerFuncWrapper is a better abstraction (similar to tf.defun in the future)
     trainer.setup_graph(
-        model.get_inputs_desc(), input,
+        model.get_input_signature(), input,
         model._build_graph_get_cost, model.get_optimizer)
     _check_unused_regularization()
     trainer.train_with_defaults(
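On the user side, the visible change is in `ModelDesc`: inputs are now declared as `tf.TensorSpec`s instead of `InputDesc`s. A minimal sketch (the model body, the dataflow `my_df`, and the learning rate are placeholders):

```python
import tensorflow as tf
from tensorpack import ModelDesc, TrainConfig, SimpleTrainer, launch_train_with_config

class MyModel(ModelDesc):
    def inputs(self):
        # was: [InputDesc(tf.float32, [None, 28, 28], 'image'), ...]
        return [tf.TensorSpec([None, 28, 28], tf.float32, 'image'),
                tf.TensorSpec([None], tf.int32, 'label')]

    def build_graph(self, image, label):
        ...  # compute and return the total cost

    def optimizer(self):
        return tf.train.GradientDescentOptimizer(1e-3)

config = TrainConfig(model=MyModel(), dataflow=my_df)  # my_df: any DataFlow (assumed defined)
launch_train_with_config(config, SimpleTrainer())
```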
...
@@ -56,11 +56,16 @@ class TowerTrainer(Trainer):
     @property
     def inputs_desc(self):
+        # TODO mark deprecated
+        return self.input_signature
+
+    @property
+    def input_signature(self):
         """
         Returns:
-            list[InputDesc]: metainfo about the inputs to the tower.
+            list[tf.TensorSpec]: metainfo about the inputs to the tower.
         """
-        return self.tower_func.inputs_desc
+        return self.tower_func.input_signature

     @property
     def towers(self):
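During the deprecation window both spellings return the same list, so downstream code can migrate at its own pace; e.g. (given some already-constructed `trainer`):

```python
specs = trainer.input_signature   # preferred: list of tf.TensorSpec
legacy = trainer.inputs_desc      # deprecated alias, same objects
assert [s.name for s in specs] == [s.name for s in legacy]
```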
@@ -124,7 +129,7 @@ class TowerTrainer(Trainer):
         if tower is None:
             input = PlaceholderInput()
-            input.setup(self.inputs_desc)
+            input.setup(self.input_signature)

         vs_name = self._vs_name_for_predictor(device_id)
         with tfv1.variable_scope(tfv1.get_variable_scope(), reuse=True), \
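This is the code path behind `TowerTrainer.get_predictor`, which builds a placeholder-fed tower from the signature on demand; typical use, with placeholder tensor names, looks roughly like:

```python
pred = trainer.get_predictor(['image'], ['logits'])  # input / output tensor names
logits, = pred(batch_of_images)                      # feeds 'image', fetches 'logits'
```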
@@ -164,7 +169,7 @@ class SingleCostTrainer(TowerTrainer):
     Base class for single-cost trainer.

     Single-cost trainer has a :meth:`setup_graph` method which takes
-    (inputs_desc, input, get_cost_fn, get_opt_fn), and build the training graph from them.
+    (input_signature, input, get_cost_fn, get_opt_fn), and builds the training graph from them.

     To use a :class:`SingleCostTrainer` object, call `trainer.setup_graph(...); trainer.train(...)`.
     """
@@ -194,12 +199,12 @@ class SingleCostTrainer(TowerTrainer):
     """

     @call_only_once
-    def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
+    def setup_graph(self, input_signature, input, get_cost_fn, get_opt_fn):
         """
         Responsible for building the main training graph for single-cost training.

         Args:
-            inputs_desc ([InputDesc]):
+            input_signature ([TensorSpec]): list of TensorSpec that describe the inputs
             input (InputSource):
             get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tensors and returns a cost tensor.
             get_opt_fn (-> tf.train.Optimizer): callable which returns an
@@ -210,12 +215,12 @@ class SingleCostTrainer(TowerTrainer):
             It must follow the `rules of tower function
             <http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer>`_.
         """
-        get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
+        get_cost_fn = TowerFuncWrapper(get_cost_fn, input_signature)
         get_opt_fn = memoized(get_opt_fn)

         self.tower_func = get_cost_fn
         # TODO setup may want to register monitor as well??
-        input_callbacks = self._setup_input(inputs_desc, input)
+        input_callbacks = self._setup_input(input_signature, input)
         train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
         self.register_callback(input_callbacks + train_callbacks)
@@ -229,9 +234,9 @@ class SingleCostTrainer(TowerTrainer):
             [Callback]: list of callbacks needed
         """

-    def _setup_input(self, inputs_desc, input):
+    def _setup_input(self, input_signature, input):
         assert not input.setup_done()
-        return input.setup(inputs_desc)
+        return input.setup(input_signature)

     def _make_get_grad_fn(self, input, get_cost_fn, get_opt_fn):
         """
...
@@ -272,7 +272,7 @@ class DistributedTrainerReplicated(DistributedTrainerBase):
         self._builder = DistributedReplicatedBuilder(gpus, server)
         self.is_chief = self._builder.is_chief

-    def _setup_input(self, inputs_desc, input):
+    def _setup_input(self, input_signature, input):
         with override_to_local_variable():
             get_global_step_var()  # gs should be local
             # input source may create variable (queue size summary)
@@ -280,7 +280,7 @@ class DistributedTrainerReplicated(DistributedTrainerBase):
             # whether something should be global or local. We now assume
             # they should be local.
             assert not input.setup_done()
-            return input.setup(inputs_desc)
+            return input.setup(input_signature)

     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
         assert isinstance(input, FeedfreeInput), input
...
@@ -132,8 +132,8 @@ class ShareSessionThread(threading.Thread):
                 yield None

     def start(self):
-        import tensorflow as tf
-        self._sess = tf.get_default_session()
+        from ..compat import tfv1
+        self._sess = tfv1.get_default_session()
         super(ShareSessionThread, self).start()

     def run(self):
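The `..compat` module used above is tensorpack's internal TF1/TF2 shim; conceptually it does something like the following (a sketch, not the actual implementation):

```python
# compat.py (sketch): expose a v1-style API namespace on any TF version
import tensorflow as tf

try:
    tfv1 = tf.compat.v1   # present in late TF 1.x and all TF 2.x releases
except AttributeError:
    tfv1 = tf             # very old TF 1.x: the top-level module is already v1
```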
...