Commit b8a50d72 authored by Yuxin Wu

InputDesc -> tf.TensorSpec everywhere

parent ba679ab1
......@@ -8,6 +8,9 @@ so you don't need to look here very often.
Here is a list of things that were changed, starting from an early version.
TensorFlow itself also changes API and those are not listed here.
+ [2019/03/20] The concept of `InputDesc` was replaced by its equivalent in TF:
`tf.TensorSpec`. This may be a breaking change if you have customized
code that relies on internals of `InputDesc`.
+ [2018/08/27] msgpack is used again for "serialization to disk", because pyarrow
offers no compatibility between versions. To use pyarrow instead, `export TENSORPACK_COMPATIBLE_SERIALIZE=pyarrow`.
+ [2018/04/05] msgpack is replaced by pyarrow because of its speed. If you want the old behavior,
......
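For readers migrating, here is a minimal before/after sketch (the tensor name is illustrative). Note that the argument order differs: `InputDesc` took `(type, shape, name)` while `tf.TensorSpec` takes `(shape, dtype, name)`:

```python
import tensorflow as tf

# Before (deprecated): InputDesc(type, shape, name)
# from tensorpack import InputDesc
# spec = InputDesc(tf.float32, (None, 224, 224, 3), 'input')

# After: tf.TensorSpec(shape, dtype, name)
spec = tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')
```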
......@@ -375,6 +375,8 @@ _DEPRECATED_NAMES = set([
'PrefetchOnGPUs',
'DistributedTrainerReplicated',
'DistributedTrainerParameterServer',
'InputDesc',
'inputs_desc',
# renamed items that should not appear in docs
'DumpTensor',
......
......@@ -48,7 +48,7 @@ Most neural network training tasks are single-cost optimization.
Tensorpack provides some trainer implementations for such tasks.
These trainers will take care of step 1 (define the graph), with the following arguments:
1. Some `InputDesc`, the metadata about the input.
1. Some `tf.TensorSpec`, the signature of the input.
2. An `InputSource`, where the input comes from. See [Input Pipeline](input-source.html).
3. A function which takes input tensors and returns the cost.
4. A function which returns an optimizer.
......
......@@ -11,7 +11,7 @@ This interface is enough for most types of single-cost tasks.
A lot of examples are written in this interface.
[SingleCost trainers](../modules/train.html#tensorpack.train.SingleCostTrainer)
expects 4 arguments to setup the graph: `InputDesc`, `InputSource`, get_cost function, and an optimizer.
expect 4 arguments to set up the graph: input signatures, `InputSource`, a get_cost function, and an optimizer.
`ModelDesc` describes a model by packing 3 of them together into one object:
```python
......@@ -62,7 +62,7 @@ The function `launch_train_with_config(config, trainer)`
uses the raw trainer interface under the hood, and is almost equivalent to the following two lines of code:
```python
trainer.setup_graph(
my_model.get_inputs_desc(),
my_model.get_input_signature(),
my_input_source, # or QueueInput(my_dataflow)
my_model.build_graph,
my_model.get_optimizer)
......
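For context, a hedged sketch of how a `ModelDesc` now declares these signatures by overriding `inputs()` (the model and tensor names are illustrative, not part of this commit):

```python
import tensorflow as tf
from tensorpack import ModelDesc

class MyModel(ModelDesc):
    def inputs(self):
        # declare the input signature with tf.TensorSpec instead of InputDesc
        return [tf.TensorSpec((None, 28, 28, 1), tf.float32, 'image'),
                tf.TensorSpec((None,), tf.int32, 'label')]
```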
......@@ -42,7 +42,7 @@ def tower_func(image):
def run_test(path, input):
param_dict = dict(np.load(path))
predictor = OfflinePredictor(PredictConfig(
inputs_desc=[InputDesc(tf.float32, (None, 227, 227, 3), 'input')],
input_signature=[tf.TensorSpec((None, 227, 227, 3), tf.float32, 'input')],
tower_func=tower_func,
session_init=DictRestore(param_dict),
input_names=['input'],
......
......@@ -97,7 +97,7 @@ def CPM(image):
def run_test(model_path, img_file):
param_dict = dict(np.load(model_path))
predict_func = OfflinePredictor(PredictConfig(
inputs_desc=[InputDesc(tf.float32, (None, 368, 368, 3), 'input')],
input_signature=[tf.TensorSpec((None, 368, 368, 3), tf.float32, 'input')],
tower_func=CPM,
session_init=DictRestore(param_dict),
input_names=['input'],
......
......@@ -59,7 +59,7 @@ def run_test(path, input):
param_dict = {k.replace('/W', '/kernel').replace('/b', '/bias'): v for k, v in six.iteritems(param_dict)}
predict_func = OfflinePredictor(PredictConfig(
inputs_desc=[InputDesc(tf.float32, (None, 224, 224, 3), 'input')],
input_signature=[tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')],
tower_func=tower_func,
session_init=DictRestore(param_dict),
input_names=['input'],
......
......@@ -62,7 +62,7 @@ def run_test(path, input):
param_dict = {k.replace('/W', '/kernel').replace('/b', '/bias'): v for k, v in six.iteritems(param_dict)}
predict_func = OfflinePredictor(PredictConfig(
inputs_desc=[InputDesc(tf.float32, (None, 224, 224, 3), 'input')],
input_signature=[tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')],
tower_func=tower_func,
session_init=DictRestore(param_dict),
input_names=['input'],
......
......@@ -88,7 +88,7 @@ class GANTrainer(TowerTrainer):
input = StagingInput(input)
# Setup input
cbs = input.setup(model.get_inputs_desc())
cbs = input.setup(model.get_input_signature())
self.register_callback(cbs)
if num_gpu <= 1:
......@@ -105,7 +105,7 @@ class GANTrainer(TowerTrainer):
not needed. Just calling model.build_graph directly is OK.
"""
# Build the graph
self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
self.tower_func = TowerFuncWrapper(model.build_graph, model.get_input_signature())
with TowerContext('', is_training=True):
self.tower_func(*input.get_input_tensors())
opt = model.get_optimizer()
......@@ -127,7 +127,7 @@ class GANTrainer(TowerTrainer):
model.build_graph(*inputs)
return [model.d_loss, model.g_loss]
self.tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
self.tower_func = TowerFuncWrapper(get_cost, model.get_input_signature())
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
cost_list = DataParallelBuilder.build_on_towers(
list(range(num_gpu)),
......@@ -163,11 +163,11 @@ class SeparateGANTrainer(TowerTrainer):
assert min(d_period, g_period) == 1
# Setup input
cbs = input.setup(model.get_inputs_desc())
cbs = input.setup(model.get_input_signature())
self.register_callback(cbs)
# Build the graph
self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
self.tower_func = TowerFuncWrapper(model.build_graph, model.get_input_signature())
with TowerContext('', is_training=True), \
argscope(BatchNorm, internal_update=True):
# should not hook the updates to both train_ops; it would hurt training speed.
......
......@@ -254,14 +254,11 @@ if __name__ == '__main__':
eval_on_ILSVRC12(model, get_model_loader(args.load), ds)
elif args.flops:
# manually build the graph with batch=1
input_desc = [
InputDesc(tf.float32, [1, 224, 224, 3], 'input'),
InputDesc(tf.int32, [1], 'label')
]
input = PlaceholderInput()
input.setup(input_desc)
with TowerContext('', is_training=False):
model.build_graph(*input.get_input_tensors())
model.build_graph(
tf.placeholder(tf.float32, [1, 224, 224, 3], 'input'),
tf.placeholder(tf.int32, [1], 'label')
)
model_utils.describe_trainable_vars()
tf.profiler.profile(
......
......@@ -64,7 +64,7 @@ class Model(ModelDesc):
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
correct = tf.cast(tf.nn.in_top_k(logits, label, 1), tf.float32, name='correct')
correct = tf.cast(tf.nn.in_top_k(predictions=logits, targets=label, k=1), tf.float32, name='correct')
# monitor training error
add_moving_summary(tf.reduce_mean(correct, name='accuracy'))
......@@ -76,7 +76,7 @@ class Model(ModelDesc):
return tf.add_n([cost, wd_cost], name='cost')
def optimizer(self):
lr = tf.get_variable('learning_rate', initializer=1e-2, trainable=False)
lr = tf.Variable(1e-2, name='learning_rate', trainable=False)
tf.summary.scalar('lr', lr)
return tf.train.AdamOptimizer(lr, epsilon=1e-3)
......
......@@ -9,7 +9,7 @@ import os
import tensorflow as tf
from tensorflow.python.keras.layers import *
from tensorpack import InputDesc, SyncMultiGPUTrainerReplicated
from tensorpack import SyncMultiGPUTrainerReplicated
from tensorpack.callbacks import *
from tensorpack.contrib.keras import KerasModel
from tensorpack.dataflow import FakeData, MapDataComponent
......@@ -166,8 +166,8 @@ if __name__ == '__main__':
M = KerasModel(
resnet50,
inputs_desc=[InputDesc(tf.uint8, [None, 224, 224, 3], 'images')],
targets_desc=[InputDesc(tf.float32, [None, 1000], 'labels')],
input_signature=[tf.TensorSpec([None, 224, 224, 3], tf.uint8, 'images')],
target_signature=[tf.TensorSpec([None, 1000], tf.float32, 'labels')],
input=df_train,
trainer=SyncMultiGPUTrainerReplicated(num_gpu))
......
......@@ -7,7 +7,7 @@ import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorpack import InputDesc, QueueInput
from tensorpack import QueueInput
from tensorpack.callbacks import ModelSaver
from tensorpack.contrib.keras import KerasModel
from tensorpack.dataflow import BatchData, MapData, dataset
......@@ -57,8 +57,8 @@ if __name__ == '__main__':
M = KerasModel(
model_func,
inputs_desc=[InputDesc(tf.float32, [None, IMAGE_SIZE, IMAGE_SIZE, 1], 'images')],
targets_desc=[InputDesc(tf.float32, [None, 10], 'labels')],
input_signature=[tf.TensorSpec([None, IMAGE_SIZE, IMAGE_SIZE, 1], tf.float32, 'images')],
target_signature=[tf.TensorSpec([None, 10], tf.float32, 'labels')],
input=QueueInput(dataset_train))
M.compile(
optimizer=tf.train.AdamOptimizer(1e-3),
......
......@@ -141,7 +141,7 @@ class InferenceRunner(InferenceRunnerBase):
if self._tower_func is None:
assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!"
self._tower_func = self.trainer.tower_func
input_callbacks = self._input_source.setup(self._tower_func.inputs_desc)
input_callbacks = self._input_source.setup(self._tower_func.input_signature)
vs_name = self.trainer._vs_name_for_predictor(self._device_id)
logger.info("[InferenceRunner] Building tower '{}' on device {} {}...".format(
......@@ -223,7 +223,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!"
self._tower_func = self.trainer.tower_func
input_callbacks = self._input_source.setup(self._tower_func.inputs_desc)
input_callbacks = self._input_source.setup(self._tower_func.input_signature)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
for idx, dev in enumerate(self._devices):
vs_name = self.trainer._vs_name_for_predictor(idx)
......
......@@ -8,8 +8,8 @@ import numpy as np
from abc import ABCMeta, abstractmethod
from collections import deque
import six
import tensorflow as tf
from ..compat import tfv1
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
from .base import Callback
......@@ -67,7 +67,7 @@ class GraphVarParam(HyperParam):
def setup_graph(self):
""" Will setup the assign operator for that variable. """
all_vars = tf.global_variables() + tf.local_variables()
all_vars = tfv1.global_variables() + tfv1.local_variables()
for v in all_vars:
if v.name == self.var_name:
self.var = v
......
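The `..compat.tfv1` import used above is tensorpack's shim for the TF1 API namespace; a simplified sketch of the idea (not the exact source):

```python
# rough equivalent of tensorpack's tfv1 compat shim
import tensorflow as tf

try:
    tfv1 = tf.compat.v1   # available in TF >= 1.13 and TF 2.x
except AttributeError:
    tfv1 = tf             # older TF 1.x: the v1 API is the top-level namespace
```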
......@@ -141,7 +141,7 @@ class KerasPhaseCallback(Callback):
def setup_keras_trainer(
trainer, get_model,
inputs_desc, targets_desc,
input_signature, target_signature,
input, optimizer, loss, metrics):
"""
Args:
......@@ -159,7 +159,7 @@ def setup_keras_trainer(
assert isinstance(metrics, list), metrics
model_caller = KerasModelCaller(get_model)
nr_inputs = len(inputs_desc)
nr_inputs = len(input_signature)
def get_cost(*inputs):
ctx = get_current_tower_context()
......@@ -211,7 +211,7 @@ def setup_keras_trainer(
return total_loss
trainer.setup_graph(
inputs_desc + targets_desc,
input_signature + target_signature,
input,
get_cost,
lambda: optimizer)
......@@ -221,23 +221,27 @@ def setup_keras_trainer(
class KerasModel(object):
def __init__(self, get_model, inputs_desc, targets_desc,
input, trainer=None):
def __init__(self, get_model, input_signature=None, target_signature=None,
input=None, trainer=None, inputs_desc=None, targets_desc=None):
"""
Args:
get_model (input1, input2, ... -> keras.Model):
A function which takes tensors, builds and returns a Keras model.
It will be part of the tower function.
inputs_desc ([InputDesc]):
targets_desc ([InputDesc]):
input (InputSource | DataFlow):
trainer (Trainer): the default will check the number of available
GPUs and use them all.
input_signature ([tf.TensorSpec]): required. The signature for inputs.
target_signature ([tf.TensorSpec]): required. The signature for the target tensors.
input (InputSource | DataFlow): the InputSource or DataFlow where the input data comes from.
trainer (Trainer): the default will check the number of available GPUs and use them all.
inputs_desc, targets_desc: deprecated names for `input_signature` and `target_signature`
"""
if inputs_desc is not None:
input_signature = inputs_desc
if targets_desc is not None:
target_signature = targets_desc
self.get_model = get_model
assert callable(get_model), get_model
self.inputs_desc = inputs_desc
self.targets_desc = targets_desc
self.input_signature = input_signature
self.target_signature = target_signature
if trainer is None:
nr_gpu = get_nr_gpu()
if nr_gpu <= 1:
......@@ -248,6 +252,7 @@ class KerasModel(object):
assert isinstance(trainer, Trainer), trainer
assert not isinstance(trainer, DistributedTrainerBase)
assert input is not None, "Argument 'input' is required!"
self.input = apply_default_prefetch(input, trainer)
self.trainer = trainer
......@@ -267,7 +272,8 @@ class KerasModel(object):
self._stats_to_inference = loss + metrics + [TOTAL_LOSS_NAME]
setup_keras_trainer(
self.trainer, get_model=self.get_model,
inputs_desc=self.inputs_desc, targets_desc=self.targets_desc,
input_signature=self.input_signature,
target_signature=self.target_signature,
input=self.input,
optimizer=optimizer,
loss=loss,
......
......@@ -18,12 +18,41 @@ TensorSpec = backport_tensor_spec()
__all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']
def build_or_reuse_placeholder(tensor_spec):
"""
Build a tf.placeholder from the metadata in the given tensor spec, or return an existing one.
Args:
tensor_spec (tf.TensorSpec):
Returns:
tf.Tensor:
"""
g = tfv1.get_default_graph()
name = tensor_spec.name
try:
tensor = g.get_tensor_by_name(name + ':0')
assert "Placeholder" in tensor.op.type, "Tensor {} exists but is not a placeholder!".format(name)
assert tensor_spec.is_compatible_with(tensor), \
"Tensor {} exists but is not compatible with the signature!".format(tensor)
return tensor
except KeyError:
with tfv1.name_scope(None): # clear any name scope it might get called in
ret = tfv1.placeholder(
tensor_spec.dtype, shape=tensor_spec.shape, name=tensor_spec.name)
return ret
class InputDesc(
namedtuple('InputDescTuple', ['type', 'shape', 'name'])):
"""
Metadata about an input entry point to the graph.
This metadata can be later used to build placeholders or other types of
input source.
An equivalent of `tf.TensorSpec`.
History: this concept was used to represent metadata about the inputs,
which could later be used to build placeholders or other types of input source.
It was introduced long before TensorFlow added the equivalent concept `tf.TensorSpec`.
We have therefore switched to `tf.TensorSpec`, but keep this class here for compatibility reasons.
"""
def __new__(cls, type, shape, name):
......@@ -33,64 +62,9 @@ class InputDesc(
shape (tuple):
name (str):
"""
shape = tuple(shape) # has to be tuple for "self" to be hashable
# TODO mark deprecated
assert isinstance(type, tf.DType), type
if any(k in name for k in [':', '/', ' ']):
raise ValueError("Invalid InputDesc name: '{}'".format(name))
self = super(InputDesc, cls).__new__(cls, type, shape, name)
self._cached_placeholder = {}
return self
def _build_placeholder(self):
"""
Build a tf.placeholder from the metadata.
Returns:
tf.Tensor:
"""
with tfv1.name_scope(None): # clear any name scope it might get called in
ret = tfv1.placeholder(
self.type, shape=self.shape, name=self.name)
self._register_cached_placeholder(ret)
return ret
# cannot memoize here, because InputDesc is hashed by its fields.
def build_placeholder_reuse(self):
"""
Build a tf.placeholder from the metadata, or return an old one.
Returns:
tf.Tensor:
"""
g = tfv1.get_default_graph()
if g in self._cached_placeholder:
return self._cached_placeholder[g]
else:
return self._build_placeholder()
def _register_cached_placeholder(self, placeholder):
graph = placeholder.graph
assert graph not in self._cached_placeholder, \
"Placeholder for this InputDesc had been created before! This is a bug."
self._cached_placeholder[graph] = placeholder
@staticmethod
def _from_placeholder(placeholder):
name = placeholder.op.name
if name.endswith('_1') or name.endswith('_2'):
logger.error("Creating InputDesc from a placeholder named {}.".format(name))
logger.error("You might have mistakenly created this placeholder multiple times!")
ret = InputDesc(
placeholder.dtype,
tuple(placeholder.shape.as_list()),
name)
ret._register_cached_placeholder(placeholder)
return ret
@staticmethod
def _from_tensor_spec(spec):
assert spec.name is not None, "TensorSpec should have a name!"
return InputDesc(spec.dtype, tuple(spec.shape.as_list()), spec.name)
return tf.TensorSpec(shape=shape, dtype=type, name=name)
class ModelDescBase(object):
......@@ -100,29 +74,22 @@ class ModelDescBase(object):
@memoized_method
def get_inputs_desc(self):
# TODO mark deprecated
return self.get_input_signature()
@memoized_method
def get_input_signature(self):
"""
Returns:
A list of :class:`InputDesc`, which describes the inputs of this model.
A list of :class:`tf.TensorSpec`, which describes the inputs of this model.
The result is cached for each instance of :class:`ModelDescBase`.
"""
try:
ret = self._get_inputs()
log_deprecated(
"ModelDescBase._get_inputs() interface",
"Use inputs() instead!",
"2019-03-30")
return ret
except NotImplementedError:
with tf.Graph().as_default() as G: # create these placeholder in a temporary graph
inputs = self.inputs()
if isinstance(inputs[0], tf.Tensor):
for p in inputs:
assert p.graph == G, "Placeholders returned by inputs() should be created inside inputs()!"
return [InputDesc._from_placeholder(p) for p in inputs]
else:
for p in inputs:
assert isinstance(p, TensorSpec), type(p)
return [InputDesc._from_tensor_spec(p) for p in inputs]
with tf.Graph().as_default() as G: # create these placeholders in a temporary graph
inputs = self.inputs()
if isinstance(inputs[0], tf.Tensor):
for p in inputs:
assert p.graph == G, "Placeholders returned by inputs() should be created inside inputs()!"
return [TensorSpec(shape=p.shape, dtype=p.dtype, name=p.name) for p in inputs]
@property
def input_names(self):
......@@ -130,7 +97,7 @@ class ModelDescBase(object):
Returns:
[str]: the names of all the inputs.
"""
return [k.name for k in self.get_inputs_desc()]
return [k.name for k in self.get_input_signature()]
def _get_inputs(self):
raise NotImplementedError()
......@@ -147,7 +114,7 @@ class ModelDescBase(object):
Also, you should never call this method by yourself.
Returns:
list[tf.placeholder] or list[tf.TensorSpec], to be converted to :class:`InputDesc`.
list[tf.TensorSpec or tf.placeholder]. To be converted to :class:`tf.TensorSpec`.
"""
raise NotImplementedError()
......@@ -166,9 +133,9 @@ class ModelDescBase(object):
may require it to return necessary information to build the trainer.
For example, `SingleCostTrainer` expect this method to return the cost tensor.
"""
assert len(args) == len(self.get_inputs_desc()), \
assert len(args) == len(self.get_input_signature()), \
"Number of inputs passed to the graph != number of inputs defined " \
"in ModelDesc! ({} != {})".format(len(args), len(self.get_inputs_desc()))
"in ModelDesc! ({} != {})".format(len(args), len(self.get_input_signature()))
log_deprecated(
"ModelDescBase._build_graph() interface",
"Use build_graph() instead!",
......
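A quick sketch of what the new module-level helper above guarantees: calling it twice with the same spec in one graph reuses the placeholder instead of creating an `input_1` duplicate (the spec below is illustrative):

```python
import tensorflow as tf
from tensorpack.graph_builder.model_desc import build_or_reuse_placeholder

spec = tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')
with tf.Graph().as_default():
    p1 = build_or_reuse_placeholder(spec)  # creates the 'input' placeholder
    p2 = build_or_reuse_placeholder(spec)  # finds 'input:0' and returns it
    assert p1 is p2
```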
......@@ -19,6 +19,7 @@ from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from ..utils.concurrency import ShareSessionThread
from .input_source_base import InputSource
from ..graph_builder.model_desc import build_or_reuse_placeholder
try:
from tensorflow.python.ops.data_flow_ops import StagingArea
......@@ -59,7 +60,7 @@ class PlaceholderInput(InputSource):
pass
def _setup(self, inputs):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._all_placehdrs = [build_or_reuse_placeholder(v) for v in inputs]
def _get_input_tensors(self):
return self._all_placehdrs
......@@ -110,7 +111,7 @@ class FeedInput(InputSource):
def _setup(self, inputs):
# placeholders as input are always safe to reuse.
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._all_placehdrs = [build_or_reuse_placeholder(v) for v in inputs]
self._cb = self._FeedCallback(self._iter_ds, self._all_placehdrs)
def _get_input_tensors(self):
......@@ -196,7 +197,7 @@ class QueueInput(FeedfreeInput):
Args:
ds(DataFlow): the input DataFlow.
queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
should match the corresponding InputDesc of the model.
should match the corresponding input signature of the model.
Defaults to a FIFO queue of size 50.
"""
if not isinstance(ds, DataFlow):
......@@ -210,12 +211,12 @@ class QueueInput(FeedfreeInput):
return len(self.ds)
def _setup(self, inputs):
self._input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._input_placehdrs = [build_or_reuse_placeholder(v) for v in inputs]
assert len(self._input_placehdrs) > 0, \
"QueueInput has to be used with some inputs!"
with self.cached_name_scope():
if self.queue is None:
self.queue = tf.FIFOQueue(
self.queue = tfv1.FIFOQueue(
50, [x.dtype for x in self._input_placehdrs],
name='input_queue')
logger.info("Setting up the queue '{}' for CPU prefetching ...".format(self.queue.name))
......@@ -287,7 +288,7 @@ class BatchQueueInput(QueueInput):
ds(DataFlow): the input DataFlow.
batch_size(int): the batch size.
queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
should match the corresponding InputDesc of the model.
should match the corresponding input signature of the model.
Defaults to a FIFO queue of size 3000.
"""
super(BatchQueueInput, self).__init__(ds, queue)
......@@ -298,9 +299,9 @@ class BatchQueueInput(QueueInput):
def _setup(self, inputs):
logger.info("Setting up the queue for CPU prefetching ...")
self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self.input_placehdrs = [build_or_reuse_placeholder(v) for v in inputs]
assert len(self.input_placehdrs) > 0, \
"BatchQueueInput has to be used with some InputDesc!"
"BatchQueueInput has to be used with some input signature!"
# prepare placeholders without the first dimension
placehdrs_nobatch = []
......@@ -364,8 +365,8 @@ class TensorInput(FeedfreeInput):
assert size > 0
self._fixed_size = size
def _setup(self, inputs_desc):
self._desc = inputs_desc
def _setup(self, input_signature):
self._spec = input_signature
def _size(self):
if self._fixed_size is None:
......@@ -376,8 +377,8 @@ class TensorInput(FeedfreeInput):
with self.cached_name_scope():
ret = self.get_tensor_fn()
assert isinstance(ret, (list, tuple)), "get_tensor_fn needs to return a list!"
assert len(ret) == len(self._desc), \
"get_tensor_fn returns {} tensors but there are {} inputs".format(len(ret), len(self._desc))
assert len(ret) == len(self._spec), \
"get_tensor_fn returns {} tensors but there are {} inputs".format(len(ret), len(self._spec))
return ret
......@@ -399,7 +400,7 @@ class DummyConstantInput(TensorInput):
assert len(self.shapes) == len(self._desc)
for idx, p in enumerate(self._desc):
tlist.append(tf.constant(
0, dtype=p.type,
0, dtype=p.dtype,
name='dummy-{}-{}'.format(p.name, ctx.index),
shape=self.shapes[idx]))
return tlist
......@@ -429,15 +430,14 @@ class ZMQInput(TensorInput):
return ret
super(ZMQInput, self).__init__(fn)
def _setup(self, inputs_desc):
assert len(inputs_desc) > 0, \
"ZMQInput has to be used with InputDesc!"
self._desc = inputs_desc
def _setup(self, input_signature):
assert len(input_signature) > 0, \
"ZMQInput has to be used with input signature!"
import zmq_ops
self._zmq_pull_socket = zmq_ops.ZMQPullSocket(
self._end_point,
[x.type for x in inputs_desc],
[x.dtype for x in input_signature],
hwm=self._hwm,
bind=self._bind)
......@@ -458,23 +458,23 @@ class TFDatasetInput(FeedfreeInput):
raise ValueError("TFDatasetInput takes a tf.data.Dataset! Got {}".format(dataset))
self._dataset = dataset
def _setup(self, inputs_desc):
self._desc = inputs_desc
def _setup(self, input_signature):
self._spec = input_signature
types = self._dataset.output_types
desc_types = tuple([k.type for k in inputs_desc])
assert len(types) == len(desc_types), \
"Dataset and InputDesc has different length! {} != {}".format(
len(types), len(desc_types))
assert types == desc_types, \
"Types of dataset and InputDesc don't match! {} != {}".format(
str(types), str(desc_types))
spec_types = tuple([k.dtype for k in input_signature])
assert len(types) == len(spec_types), \
"Dataset and input signature have different length! {} != {}".format(
len(types), len(spec_types))
assert types == spec_types, \
"Data types of dataset and input signature don't match! {} != {}".format(
str(types), str(spec_types))
shapes = self._dataset.output_shapes
desc_shapes = [k.shape for k in inputs_desc]
for idx, (s1, s2) in enumerate(zip(shapes, desc_shapes)):
spec_shapes = [k.shape for k in input_signature]
for idx, (s1, s2) in enumerate(zip(shapes, spec_shapes)):
s2 = tf.TensorShape(s2)
assert s2.is_compatible_with(s1), \
"InputDesc '{}' has incompatible shape with dataset! {} vs {}".format(
inputs_desc[idx].name, s2, s1)
"Input signature '{}' has incompatible shape with dataset! {} vs {}".format(
input_signature[idx].name, s2, s1)
self._iterator = self._dataset.make_initializable_iterator()
self._init_op = self._iterator.initializer
......@@ -482,11 +482,11 @@ class TFDatasetInput(FeedfreeInput):
self._init_op.run()
def _get_input_tensors(self):
desc_shapes = [k.shape for k in self._desc]
spec_shapes = [k.shape for k in self._spec]
ret = self._iterator.get_next()
assert len(ret) == len(desc_shapes), \
"Dataset returns {} tensors but there are {} inputs!".format(len(ret), len(desc_shapes))
for t, shp in zip(ret, desc_shapes):
assert len(ret) == len(spec_shapes), \
"Dataset returns {} tensors but there are {} inputs!".format(len(ret), len(spec_shapes))
for t, shp in zip(ret, spec_shapes):
t.set_shape(shp)
return ret
......
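To illustrate the renamed `setup()` argument across the `InputSource` classes above, a minimal hedged sketch (the shapes and the `FakeData` dataflow are illustrative):

```python
import tensorflow as tf
from tensorpack import QueueInput
from tensorpack.dataflow import FakeData

# a dataflow yielding [image_batch, label_batch] datapoints of float32
df = FakeData([[64, 28, 28, 1], [64, 10]], size=100)
input = QueueInput(df)
# setup() now takes a list of tf.TensorSpec instead of InputDesc
callbacks = input.setup([tf.TensorSpec((None, 28, 28, 1), tf.float32, 'image'),
                         tf.TensorSpec((None, 10), tf.float32, 'label')])
```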
......@@ -12,6 +12,7 @@ from ..callbacks.base import CallbackFactory
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
from ..utils.argtools import call_only_once, memoized_method
from ..graph_builder.model_desc import build_or_reuse_placeholder
__all__ = ['InputSource', 'remap_input_source']
......@@ -86,20 +87,20 @@ class InputSource(object):
pass
@call_only_once
def setup(self, inputs_desc):
def setup(self, input_signature):
"""
Args:
inputs_desc (list[InputDesc]): list of input desc
input_signature (list[tf.TensorSpec]): list of specs for each input tensor
Returns:
list[Callback]: extra callbacks needed by this InputSource.
callbacks of InputSource cannot use any `trigger*()` method.
"""
self._setup(inputs_desc)
self._setup(input_signature)
self._setup_done = True
return self.get_callbacks()
def _setup(self, inputs_desc):
def _setup(self, input_signature):
pass
def setup_done(self):
......@@ -190,8 +191,8 @@ class ProxyInputSource(InputSource):
def _get_input_tensors(self):
return self._input.get_input_tensors()
def _setup(self, inputs_desc):
self._input.setup(inputs_desc)
def _setup(self, input_signature):
self._input.setup(input_signature)
def _get_callbacks(self):
return self._input.get_callbacks()
......@@ -226,11 +227,11 @@ def remap_input_source(input, names):
input1 = QueueInput(ds)
# assume ds produces 'image' and 'label', but the graph takes more
# inputs for some reason, or takes inputs in a different order:
inputs_desc = [InputDesc(tf.float32, (None,10), 'score'),
InputDesc(tf.float32, (None,20,20,3), 'label'),
InputDesc(tf.int32, (None,), 'image') ]
input_signature = [tf.TensorSpec((None,10), tf.float32, 'score'),
tf.TensorSpec((None,20,20,3), tf.float32, 'label'),
tf.TensorSpec((None,), tf.int32, 'image') ]
input2 = remap_input_source(input1, ['image', 'label'])
input2.setup(inputs_desc)
input2.setup(input_signature)
# now, input2.get_input_tensors() will return a placeholder for 'score',
# plus the tensors returned by input1.get_input_tensors()
"""
......@@ -240,7 +241,7 @@ def remap_input_source(input, names):
self._names = tuple(names)
def _setup(self, inputs):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._all_placehdrs = [build_or_reuse_placeholder(v) for v in inputs]
inputs_subset = get_sublist_by_names(inputs, self._names)
self._input.setup(inputs_subset)
......
......@@ -155,7 +155,7 @@ class OfflinePredictor(OnlinePredictor):
self.graph = config._maybe_create_graph()
with self.graph.as_default():
input = PlaceholderInput()
input.setup(config.inputs_desc)
input.setup(config.input_signature)
with PredictTowerContext(''):
config.tower_func(*input.get_input_tensors())
......
......@@ -18,7 +18,7 @@ class PredictConfig(object):
def __init__(self,
model=None,
tower_func=None,
inputs_desc=None,
input_signature=None,
input_names=None,
output_names=None,
......@@ -27,11 +27,18 @@ class PredictConfig(object):
session_init=None,
return_input=False,
create_graph=True,
inputs_desc=None
):
"""
You need to set either `model`, or `inputs_desc` plus `tower_func`.
They are needed to construct the graph.
You'll also have to set `output_names` as it does not have a default.
Users need to provide enough arguments to construct a tower function,
which will be used to build the graph.
This can be done in one of the following ways:
1. `model`: a :class:`ModelDesc` instance. It will contain a tower function by itself.
2. `tower_func`: a :class:`tfutils.TowerFuncWrapper` instance.
Provide a tower function instance directly.
3. `tower_func`: a symbolic function and `input_signature`: the signature of the function.
Provide both a function and its signature.
Example:
......@@ -42,15 +49,14 @@ class PredictConfig(object):
output_names=['linear/output', 'prediction'])
Args:
model (ModelDescBase): to be used to obtain inputs_desc and tower_func.
model (ModelDescBase): to be used to construct a tower function.
tower_func: a callable which takes input tensors (by positional args) and construct a tower.
or a :class:`tfutils.TowerFuncWrapper` instance, which packs both `inputs_desc` and function together.
inputs_desc ([InputDesc]): if tower_func is a plain function (instead of a TowerFuncWrapper), this describes
the list of inputs it takes.
or a :class:`tfutils.TowerFuncWrapper` instance.
input_signature ([tf.TensorSpec]): if tower_func is a plain function (instead of a TowerFuncWrapper),
this describes the list of inputs it takes.
input_names (list): a list of input tensor names. Defaults to match inputs_desc.
The name can be either the name of a tensor, or the name of one input defined
by `inputs_desc` or by `model`.
input_names (list): a list of input tensor names. Defaults to match input_signature.
The name can be either the name of a tensor, or the name of one input of the tower.
output_names (list): a list of names of the output tensors to predict. These
can be any tensors in the graph that are computable from the tensors corresponding to `input_names`.
......@@ -62,23 +68,29 @@ class PredictConfig(object):
return_input (bool): same as in :attr:`PredictorBase.return_input`.
create_graph (bool): create a new graph, or use the default graph
when predictor is first initialized.
inputs_desc (list[tf.TensorSpec]): old (deprecated) name for `input_signature`.
"""
def assert_type(v, tp, name):
assert isinstance(v, tp), \
"{} has to be type '{}', but an object of type '{}' found.".format(
"Argument '{}' has to be type '{}', but an object of type '{}' found.".format(
name, tp.__name__, v.__class__.__name__)
if inputs_desc is not None:
# TODO warn deprecated or not?
assert input_signature is None, "Cannot set both inputs_desc and input_signature!"
input_signature = inputs_desc
if model is not None:
assert_type(model, ModelDescBase, 'model')
assert inputs_desc is None and tower_func is None
self.inputs_desc = model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(model.build_graph, self.inputs_desc)
assert input_signature is None and tower_func is None
self.input_signature = model.get_input_signature()
self.tower_func = TowerFuncWrapper(model.build_graph, self.input_signature)
else:
if isinstance(tower_func, TowerFuncWrapper):
inputs_desc = tower_func.inputs_desc
assert inputs_desc is not None and tower_func is not None
self.inputs_desc = inputs_desc
self.tower_func = TowerFuncWrapper(tower_func, inputs_desc)
input_signature = tower_func.input_signature
assert input_signature is not None and tower_func is not None
self.input_signature = input_signature
self.tower_func = TowerFuncWrapper(tower_func, input_signature)
if session_init is None:
session_init = JustCurrentSession()
......@@ -93,20 +105,22 @@ class PredictConfig(object):
# inputs & outputs
self.input_names = input_names
if self.input_names is None:
self.input_names = [k.name for k in self.inputs_desc]
self.input_names = [k.name for k in self.input_signature]
assert output_names is not None, "Argument 'output_names' is not provided!"
self.output_names = output_names
assert_type(self.output_names, list, 'output_names')
assert_type(self.input_names, list, 'input_names')
if len(self.input_names) == 0:
logger.warn('PredictConfig receives empty "input_names".')
# assert len(self.input_names), self.input_names
for v in self.input_names:
assert_type(v, six.string_types, 'Each item in input_names')
assert len(self.output_names), self.output_names
assert len(self.output_names), "Argument 'output_names' cannot be empty!"
self.return_input = bool(return_input)
self.create_graph = bool(create_graph)
self.inputs_desc = input_signature # TODO: kept under the old name for backward compatibility
def _maybe_create_graph(self):
if self.create_graph:
return tf.Graph()
......
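Summarizing the constructor's options, a hedged sketch of the two most common spellings (`my_model` and `my_tower_fn` are hypothetical placeholders):

```python
import tensorflow as tf
from tensorpack import PredictConfig

# Option 1: a ModelDesc carries its own tower function and signature.
config = PredictConfig(
    model=my_model,                       # hypothetical ModelDesc instance
    input_names=['input'],
    output_names=['prediction'])

# Option 3: a plain function plus an explicit input_signature.
config = PredictConfig(
    tower_func=my_tower_fn,               # hypothetical symbolic function
    input_signature=[tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')],
    input_names=['input'],
    output_names=['prediction'])
```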
......@@ -21,7 +21,7 @@ class FeedfreePredictor(PredictorBase):
Args:
config (PredictConfig): the config to use.
input_source (InputSource): the feedfree InputSource to use.
Must match the inputs_desc in config.
Must match the signature of the tower function in config.
"""
self._config = config
self._input_source = input_source
......@@ -33,7 +33,7 @@ class FeedfreePredictor(PredictorBase):
self.graph = config._maybe_create_graph()
with self.graph.as_default():
self._input_callbacks = Callbacks(
self._input_source.setup(config.inputs_desc))
self._input_source.setup(config.input_signature))
with PredictTowerContext(''):
self._input_tensors = self._input_source.get_input_tensors()
config.tower_func(*self._input_tensors)
......
......@@ -4,7 +4,6 @@
import tensorflow as tf
from ..graph_builder.model_desc import InputDesc
from ..input_source import PlaceholderInput
from ..tfutils.tower import PredictTowerContext
from ..utils import logger
......@@ -33,7 +32,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
handles = []
input = PlaceholderInput()
input.setup(config.inputs_desc)
input.setup(config.input_signature)
for idx, t in enumerate(towers):
tower_name = 'tower' + str(t)
......@@ -102,10 +101,10 @@ class DataParallelOfflinePredictor(OnlinePredictor):
for idx, t in enumerate(towers):
tower_name = 'tower' + str(t)
inputs_desc = [InputDesc(desc.type, desc.shape, tower_name + '_' + desc.name)
for desc in config.inputs_desc]
new_sig = [tf.TensorSpec(dtype=p.dtype, shape=p.shape, name=tower_name + '_' + p.name)
for p in config.input_signature]
input = PlaceholderInput()
input.setup(inputs_desc)
input.setup(new_sig)
with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0), \
tf.device('/gpu:{}'.format(t)), \
......
......@@ -28,7 +28,7 @@ class ModelExporter(object):
Args:
config (PredictConfig): the config to use.
The graph will be built with `config.tower_func` and `config.inputs_desc`.
The graph will be built with the tower function defined by this `PredictConfig`.
Then the input / output names will be used to export models for inference.
"""
super(ModelExporter, self).__init__()
......@@ -51,7 +51,7 @@ class ModelExporter(object):
self.graph = self.config._maybe_create_graph()
with self.graph.as_default():
input = PlaceholderInput()
input.setup(self.config.inputs_desc)
input.setup(self.config.input_signature)
with PredictTowerContext(''):
self.config.tower_func(*input.get_input_tensors())
......@@ -116,7 +116,7 @@ class ModelExporter(object):
self.graph = self.config._maybe_create_graph()
with self.graph.as_default():
input = PlaceholderInput()
input.setup(self.config.inputs_desc)
input.setup(self.config.input_signature)
with PredictTowerContext(''):
self.config.tower_func(*input.get_input_tensors())
......
......@@ -257,24 +257,27 @@ class TowerFuncWrapper(object):
Conceptually, this class is roughly equivalent to `tf.function` with input signature, introduced in TF 2.0.
"""
def __init__(self, tower_fn, inputs_desc):
def __init__(self, tower_fn, input_signature):
"""
Args:
tower_func: a function which builds one tower in the graph.
It takes several input tensors and could return anything.
inputs_desc ([InputDesc]): list of :class:`InputDesc`.
input_signature ([TensorSpec]): list of :class:`tf.TensorSpec`.
They are used to figure out the names for the input tensors.
"""
assert callable(tower_fn), tower_fn
self._inputs_desc_names = [k.name for k in inputs_desc]
assert len(set(self._inputs_desc_names)) == len(self._inputs_desc_names), \
"Duplicated names in inputs_desc! " + str(self._inputs_desc_names)
self._inputs_names = [k.name for k in input_signature]
assert len(set(self._inputs_names)) == len(self._inputs_names), \
"Duplicated names in input_signature! " + str(self._inputs_names)
for name in self._inputs_names:
if any(k in name for k in [':', '/', ' ']):
raise ValueError("Invalid input name: '{}'".format(name))
self._tower_fn = tower_fn
self._inputs_desc = inputs_desc
self._input_signature = input_signature
self._handles = []
def __new__(cls, tower_fn, inputs_desc):
def __new__(cls, tower_fn, _):
# to avoid double-wrapping a function
if isinstance(tower_fn, TowerFuncWrapper):
return tower_fn
......@@ -285,7 +288,7 @@ class TowerFuncWrapper(object):
ctx = get_current_tower_context()
assert ctx is not None, "Function must be called under TowerContext!"
output = self._tower_fn(*args)
handle = TowerTensorHandle(ctx, args, output, self._inputs_desc)
handle = TowerTensorHandle(ctx, args, output, self._input_signature)
self._handles.append(handle)
return output
......@@ -298,9 +301,14 @@ class TowerFuncWrapper(object):
"""
return TowerTensorHandles(self._handles)
@property
def input_signature(self):
return self._input_signature
@property
def inputs_desc(self):
return self._inputs_desc
# TODO mark deprecated
return self._input_signature
class TowerTensorHandles(object):
......@@ -354,14 +362,14 @@ class TowerTensorHandle(object):
"""
@HIDE_DOC
def __init__(self, ctx, input, output, inputs_desc=None):
def __init__(self, ctx, input, output, input_signature=None):
self._ctx = ctx
self._extra_tensor_names = {}
if inputs_desc is not None:
assert len(inputs_desc) == len(input)
if input_signature is not None:
assert len(input_signature) == len(input)
self._extra_tensor_names = {
get_op_tensor_name(x.name)[1]: y for x, y in zip(inputs_desc, input)}
get_op_tensor_name(x.name)[1]: y for x, y in zip(input_signature, input)}
self._input = input
self._output = output
......@@ -379,7 +387,7 @@ class TowerTensorHandle(object):
1. The name of the tensor without any tower prefix.
2. The name of an :class:`InputDesc`, if it is used when building the tower.
2. A name in the input signature, if it is used when building the tower.
In the second case, this method will return the tensor that's used as the corresponding
input to the tower. Note that this tensor may have a different name (e.g. may be an output of a queue).
......
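A minimal sketch of the renamed wrapper in use (the tower function and shapes are illustrative):

```python
import tensorflow as tf
from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper

def my_tower(image):
    # a trivial tower body, just to give the wrapper something to call
    return tf.reduce_mean(image, name='mean')

func = TowerFuncWrapper(
    my_tower, [tf.TensorSpec((None, 4), tf.float32, 'image')])
print(func.input_signature)  # the new property
print(func.inputs_desc)      # deprecated alias, returns the same list

with TowerContext('', is_training=False):
    func(tf.placeholder(tf.float32, (None, 4), name='image'))
```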
......@@ -87,7 +87,7 @@ def launch_train_with_config(config, trainer):
# We should gradually move away from this unhelpful abstraction.
# TowerFuncWrapper is a better abstraction (similar to tf.defun in the future)
trainer.setup_graph(
model.get_inputs_desc(), input,
model.get_input_signature(), input,
model._build_graph_get_cost, model.get_optimizer)
_check_unused_regularization()
trainer.train_with_defaults(
......
......@@ -56,11 +56,16 @@ class TowerTrainer(Trainer):
@property
def inputs_desc(self):
# TODO mark deprecated
return self.input_signature
@property
def input_signature(self):
"""
Returns:
list[InputDesc]: metainfo about the inputs to the tower.
list[tf.TensorSpec]: metainfo about the inputs to the tower.
"""
return self.tower_func.inputs_desc
return self.tower_func.input_signature
@property
def towers(self):
......@@ -124,7 +129,7 @@ class TowerTrainer(Trainer):
if tower is None:
input = PlaceholderInput()
input.setup(self.inputs_desc)
input.setup(self.input_signature)
vs_name = self._vs_name_for_predictor(device_id)
with tfv1.variable_scope(tfv1.get_variable_scope(), reuse=True), \
......@@ -164,7 +169,7 @@ class SingleCostTrainer(TowerTrainer):
Base class for single-cost trainer.
Single-cost trainer has a :meth:`setup_graph` method which takes
(inputs_desc, input, get_cost_fn, get_opt_fn), and build the training graph from them.
(input_signature, input, get_cost_fn, get_opt_fn), and builds the training graph from them.
To use a :class:`SingleCostTrainer` object, call `trainer.setup_graph(...); trainer.train(...)`.
"""
......@@ -194,12 +199,12 @@ class SingleCostTrainer(TowerTrainer):
"""
@call_only_once
def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
def setup_graph(self, input_signature, input, get_cost_fn, get_opt_fn):
"""
Responsible for building the main training graph for single-cost training.
Args:
inputs_desc ([InputDesc]):
input_signature ([TensorSpec]): list of TensorSpec that describe the inputs
input (InputSource):
get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tensors and returns a cost tensor.
get_opt_fn (-> tf.train.Optimizer): callable which returns an
......@@ -210,12 +215,12 @@ class SingleCostTrainer(TowerTrainer):
It must follow the `rules of the tower function
<http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer>`_.
"""
get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
get_cost_fn = TowerFuncWrapper(get_cost_fn, input_signature)
get_opt_fn = memoized(get_opt_fn)
self.tower_func = get_cost_fn
# TODO setup may want to register monitor as well??
input_callbacks = self._setup_input(inputs_desc, input)
input_callbacks = self._setup_input(input_signature, input)
train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
self.register_callback(input_callbacks + train_callbacks)
......@@ -229,9 +234,9 @@ class SingleCostTrainer(TowerTrainer):
[Callback]: list of callbacks needed
"""
def _setup_input(self, inputs_desc, input):
def _setup_input(self, input_signature, input):
assert not input.setup_done()
return input.setup(inputs_desc)
return input.setup(input_signature)
def _make_get_grad_fn(self, input, get_cost_fn, get_opt_fn):
"""
......
......@@ -272,7 +272,7 @@ class DistributedTrainerReplicated(DistributedTrainerBase):
self._builder = DistributedReplicatedBuilder(gpus, server)
self.is_chief = self._builder.is_chief
def _setup_input(self, inputs_desc, input):
def _setup_input(self, input_signature, input):
with override_to_local_variable():
get_global_step_var() # gs should be local
# input source may create variable (queue size summary)
......@@ -280,7 +280,7 @@ class DistributedTrainerReplicated(DistributedTrainerBase):
# whether something should be global or local. We now assume
# they should be local.
assert not input.setup_done()
return input.setup(inputs_desc)
return input.setup(input_signature)
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
assert isinstance(input, FeedfreeInput), input
......
......@@ -132,8 +132,8 @@ class ShareSessionThread(threading.Thread):
yield None
def start(self):
import tensorflow as tf
self._sess = tf.get_default_session()
from ..compat import tfv1
self._sess = tfv1.get_default_session()
super(ShareSessionThread, self).start()
def run(self):
......