Commit 5b29bda9 authored by Yuxin Wu

use InputDesc instead of InputVar.

parent bd686aab
@@ -8,6 +8,8 @@ so you won't need to look at here very often.
Here is a list of things that were changed, starting from an early version.
TensorFlow itself also changes its API and those are not listed here.
+* 2017/02/11. `_get_input_vars()` in `ModelDesc` was renamed to `_get_inputs`. `InputVar` was
+  renamed to `InputDesc`.
* 2017/01/27. `TrainConfig(step_per_epoch)` was renamed to `steps_per_epoch`. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/a9dd0b8ec34209ab86a92875589dbbc4716e73ef).
* 2017/01/25. Argument order of `models.ConcatWith` is changed to follow the API change in
  TensorFlow upstream. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/2df3dcf401a99fe61c699ad719e95528872d3abe).
...
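For readers migrating their own code, here is a minimal before/after sketch of the rename. The `MyModel` class and its input shapes are illustrative assumptions rather than code from this commit; `InputDesc` keeps the same `(dtype, shape, name)` arguments as the old `InputVar`:

```python
import tensorflow as tf
from tensorpack import ModelDesc, InputDesc   # InputVar is still exported as a deprecated alias

class MyModel(ModelDesc):        # hypothetical model, for illustration only
    # old spelling (before 2017/02/11):
    #   def _get_input_vars(self):
    #       return [InputVar(tf.float32, [None, 28, 28], 'input'),
    #               InputVar(tf.int32, [None], 'label')]

    # new spelling:
    def _get_inputs(self):
        return [InputDesc(tf.float32, [None, 28, 28], 'input'),
                InputDesc(tf.int32, [None], 'label')]

    def _build_graph(self, inputs):
        image, label = inputs
        # ... build the network and set self.cost as usual ...
```

Since `InputVar` is kept as an alias of `InputDesc` (see the last file in this diff), existing code keeps working; `_get_input_vars()` only triggers a deprecation warning.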
@@ -145,8 +145,7 @@ If you are surprised how far we already are, you will enjoy how easy it is to de
```python
class Model(ModelDesc):
-    def _get_input_vars(self):
+    def _get_inputs(self):
        pass
    def _build_graph(self, input_vars):
@@ -154,7 +153,7 @@ class Model(ModelDesc):
```
The framework expects:
-- a definition of inputs in `_get_input_vars`
+- a definition of inputs in `_get_inputs`
- a computation graph containing the actual network layers in `_build_graph`
- In a single-cost optimization problem, a member `self.cost` representing the loss function we would like to minimize.
@@ -163,23 +162,23 @@ Our dataflow produces data which looks like `[(32, 256, 256), (32, 256, 256, 3)]
The first entry is the luminance channel as input and the second is the original RGB image with all three channels. So we will write
```python
-def _get_input_vars(self):
-    return [InputVar(tf.float32, (None, 256, 256), 'luminance'),
-            InputVar(tf.int32, (None, 256, 256, 3), 'rgb')]
+def _get_inputs(self):
+    return [InputDesc(tf.float32, (None, 256, 256), 'luminance'),
+            InputDesc(tf.int32, (None, 256, 256, 3), 'rgb')]
```
This is pretty straightforward, isn't it? We define the shapes of the inputs and give each entry a name.
You can certainly use 32 instead of `None`, but since the model itself doesn't really need to know
the batch size, using `None` offers the extra flexibility to run inference with a different batch size in the same graph.
-From now on, the `input_vars` in `_build_graph(self, input_vars)` will be the tensors of the defined shapes in the method `_get_input_vars`. We can therefore write
+From now on, the `input_vars` in `_build_graph(self, input_vars)` will be the tensors of the defined shapes in the method `_get_inputs`. We can therefore write
```python
class Model(ModelDesc):
-    def _get_input_vars(self):
-        return [InputVar(tf.float32, (None, 256, 256), 'luminance'),
-                InputVar(tf.int32, (None, 256, 256, 3), 'rgb')]
+    def _get_inputs(self):
+        return [InputDesc(tf.float32, (None, 256, 256), 'luminance'),
+                InputDesc(tf.int32, (None, 256, 256, 3), 'rgb')]
    def _build_graph(self, input_vars):
        luminance, rgb = input_vars  # (None, 256, 256), (None, 256, 256, 3)
@@ -354,7 +353,7 @@ class OnlineExport(Callback):
        pass
```
-Can you remember the method `_get_input_vars` in our model? We used the name `luminance` to identify one input.
+Can you remember the method `_get_inputs` in our model? We used the name `luminance` to identify one input.
If not, this is the best time to go back in this text and read how to specify input variables for the network.
In the deconvolution step there was also:
...
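To tie the three expectations together (inputs, graph, cost), here is a compact self-contained sketch. The tiny 1x1-convolution "network" and the squared-error cost are placeholder assumptions for illustration, not the tutorial's actual colorization model:

```python
import tensorflow as tf
from tensorpack import ModelDesc, InputDesc

class MinimalModel(ModelDesc):   # illustrative stand-in for the tutorial's Model
    def _get_inputs(self):
        # declare dtype, shape and name of every input the graph will be fed
        return [InputDesc(tf.float32, (None, 256, 256), 'luminance'),
                InputDesc(tf.int32, (None, 256, 256, 3), 'rgb')]

    def _build_graph(self, input_vars):
        luminance, rgb = input_vars
        # deliberately tiny "network": predict RGB from luminance with one 1x1 conv
        x = tf.expand_dims(luminance, -1)                      # NHWC with one channel
        pred = tf.layers.conv2d(x, 3, 1, name='predict_rgb')   # (None, 256, 256, 3)
        # single-cost optimization: the trainer minimizes self.cost
        self.cost = tf.reduce_mean(
            tf.squared_difference(pred, tf.cast(rgb, tf.float32)), name='total_cost')
```

A trainer then only needs this class plus a dataflow; placeholder creation, feeding and the minimization of `self.cost` are handled by the framework.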
@@ -7,7 +7,7 @@ you'll need to subclass `ModelDesc` and implement several methods:
```python
class MyModel(ModelDesc):
    def _get_inputs(self):
-        return [InputVar(...), InputVar(...)]
+        return [InputDesc(...), InputDesc(...)]
    def _build_graph(self, inputs):
        # build the graph
...
@@ -41,9 +41,9 @@ def get_player(dumpdir=None):
class Model(ModelDesc):
    def _get_inputs(self):
        assert NUM_ACTIONS is not None
-        return [InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
-                InputVar(tf.int32, (None,), 'action'),
-                InputVar(tf.float32, (None,), 'futurereward')]
+        return [InputDesc(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
+                InputDesc(tf.int32, (None,), 'action'),
+                InputDesc(tf.float32, (None,), 'futurereward')]
    def _get_NN_prediction(self, image):
        image = image / 255.0
...
@@ -78,9 +78,9 @@ class MySimulatorWorker(SimulatorProcess):
class Model(ModelDesc):
    def _get_inputs(self):
        assert NUM_ACTIONS is not None
-        return [InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
-                InputVar(tf.int64, (None,), 'action'),
-                InputVar(tf.float32, (None,), 'futurereward')]
+        return [InputDesc(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
+                InputDesc(tf.int64, (None,), 'action'),
+                InputDesc(tf.float32, (None,), 'futurereward')]
    def _get_NN_prediction(self, image):
        image = image / 255.0
...
@@ -28,11 +28,11 @@ FEATUREDIM = 39 # MFCC feature dimension
class Model(ModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.float32, [None, None, FEATUREDIM], 'feat'),  # bxmaxseqx39
-                InputVar(tf.int64, None, 'labelidx'),  # label is b x maxlen, sparse
-                InputVar(tf.int32, None, 'labelvalue'),
-                InputVar(tf.int64, None, 'labelshape'),
-                InputVar(tf.int32, [None], 'seqlen'),  # b
+        return [InputDesc(tf.float32, [None, None, FEATUREDIM], 'feat'),  # bxmaxseqx39
+                InputDesc(tf.int64, None, 'labelidx'),  # label is b x maxlen, sparse
+                InputDesc(tf.int32, None, 'labelvalue'),
+                InputDesc(tf.int64, None, 'labelshape'),
+                InputDesc(tf.int32, [None], 'seqlen'),  # b
                ]
    def _build_graph(self, inputs):
...
@@ -66,8 +66,8 @@ class CharRNNData(RNGDataFlow):
class Model(ModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.int32, (None, param.seq_len), 'input'),
-                InputVar(tf.int32, (None, param.seq_len), 'nextinput')]
+        return [InputDesc(tf.int32, (None, param.seq_len), 'input'),
+                InputDesc(tf.int32, (None, param.seq_len), 'nextinput')]
    def _build_graph(self, inputs):
        input, nextinput = inputs
...
@@ -45,8 +45,8 @@ def get_gaussian_map():
class Model(ModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.float32, (None, 368, 368, 3), 'input'),
-                InputVar(tf.float32, (None, 368, 368, 15), 'label'),
+        return [InputDesc(tf.float32, (None, 368, 368, 3), 'input'),
+                InputDesc(tf.float32, (None, 368, 368, 15), 'label'),
                ]
    def _build_graph(self, inputs):
...
@@ -72,11 +72,11 @@ class Model(ModelDesc):
        if NUM_ACTIONS is None:
            p = get_player()
            del p
-        return [InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
-                InputVar(tf.int64, (None,), 'action'),
-                InputVar(tf.float32, (None,), 'reward'),
-                InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'next_state'),
-                InputVar(tf.bool, (None,), 'isOver')]
+        return [InputDesc(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
+                InputDesc(tf.int64, (None,), 'action'),
+                InputDesc(tf.float32, (None,), 'reward'),
+                InputDesc(tf.float32, (None,) + IMAGE_SHAPE3, 'next_state'),
+                InputDesc(tf.bool, (None,), 'isOver')]
    def _get_DQN_prediction(self, image):
        """ image: [0,255]"""
...
@@ -75,8 +75,8 @@ BATCH_SIZE = None
class Model(ModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.float32, [None, 224, 224, 3], 'input'),
-                InputVar(tf.int32, [None], 'label')]
+        return [InputDesc(tf.float32, [None, 224, 224, 3], 'input'),
+                InputDesc(tf.int32, [None], 'label')]
    def _build_graph(self, inputs):
        image, label = inputs
...
@@ -34,8 +34,8 @@ BITG = 32
class Model(ModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.float32, [None, 224, 224, 3], 'input'),
-                InputVar(tf.int32, [None], 'label')]
+        return [InputDesc(tf.float32, [None, 224, 224, 3], 'input'),
+                InputDesc(tf.int32, [None], 'label')]
    def _build_graph(self, inputs):
        image, label = inputs
...
@@ -44,8 +44,8 @@ BITG = 4
class Model(ModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.float32, [None, 40, 40, 3], 'input'),
-                InputVar(tf.int32, [None], 'label')]
+        return [InputDesc(tf.float32, [None, 40, 40, 3], 'input'),
+                InputDesc(tf.int32, [None], 'label')]
    def _build_graph(self, inputs):
        image, label = inputs
...
@@ -27,9 +27,9 @@ BATCH = 128
class Model(GANModelDesc):
-    def _get_input_vars(self):
-        return [InputVar(tf.float32, (None, 28, 28), 'input'),
-                InputVar(tf.int32, (None,), 'label')]
+    def _get_inputs(self):
+        return [InputDesc(tf.float32, (None, 28, 28), 'input'),
+                InputDesc(tf.int32, (None,), 'label')]
    def generator(self, z, y):
        l = FullyConnected('fc0', tf.concat([z, y], 1), 1024, nl=BNReLU)
...
@@ -36,7 +36,7 @@ Z_DIM = 100
class Model(GANModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.float32, (None, SHAPE, SHAPE, 3), 'input')]
+        return [InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'input')]
    def generator(self, z):
        """ return a image generated from z"""
...
@@ -44,8 +44,8 @@ NF = 64 # number of filter
class Model(GANModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.float32, (None, SHAPE, SHAPE, IN_CH), 'input'),
-                InputVar(tf.float32, (None, SHAPE, SHAPE, OUT_CH), 'output')]
+        return [InputDesc(tf.float32, (None, SHAPE, SHAPE, IN_CH), 'input'),
+                InputDesc(tf.float32, (None, SHAPE, SHAPE, OUT_CH), 'output')]
    def generator(self, imgs):
        # imgs: input: 256x256xch
...
@@ -41,7 +41,7 @@ class GaussianWithUniformSample(GaussianDistribution):
class Model(GANModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.float32, (None, 28, 28), 'input')]
+        return [InputDesc(tf.float32, (None, 28, 28), 'input')]
    def generator(self, z):
        l = FullyConnected('fc0', z, 1024, nl=BNReLU)
...
@@ -18,8 +18,8 @@ from tensorpack.tfutils.summary import *
class Model(ModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.float32, [None, None, None, 3], 'image'),
-                InputVar(tf.int32, [None, None, None], 'edgemap')]
+        return [InputDesc(tf.float32, [None, None, None, 3], 'image'),
+                InputDesc(tf.int32, [None, None, None], 'edgemap')]
    def _build_graph(self, inputs):
        image, edgemap = inputs
...
@@ -30,8 +30,8 @@ Learning rate may need a different schedule for different number of GPUs (becaus
class Model(ModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.float32, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
-                InputVar(tf.int32, [None], 'label')]
+        return [InputDesc(tf.float32, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
+                InputDesc(tf.int32, [None], 'label')]
    def _build_graph(self, inputs):
        image, label = inputs
...
@@ -35,8 +35,8 @@ INPUT_SHAPE = 299
class Model(ModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.float32, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
-                InputVar(tf.int32, [None], 'label')]
+        return [InputDesc(tf.float32, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
+                InputDesc(tf.int32, [None], 'label')]
    def _build_graph(self, inputs):
        image, label = inputs
...
@@ -45,8 +45,8 @@ def get_PennTreeBank(data_dir=None):
class Model(ModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.int32, (None, SEQ_LEN), 'input'),
-                InputVar(tf.int32, (None, SEQ_LEN), 'nextinput')]
+        return [InputDesc(tf.int32, (None, SEQ_LEN), 'input'),
+                InputDesc(tf.int32, (None, SEQ_LEN), 'nextinput')]
    def _build_graph(self, inputs):
        is_training = get_current_tower_context().is_training
...
@@ -38,8 +38,8 @@ class Model(ModelDesc):
        self.n = n
    def _get_inputs(self):
-        return [InputVar(tf.float32, [None, 32, 32, 3], 'input'),
-                InputVar(tf.int32, [None], 'label')]
+        return [InputDesc(tf.float32, [None, 32, 32, 3], 'input'),
+                InputDesc(tf.int32, [None], 'label')]
    def _build_graph(self, inputs):
        image, label = inputs
...
@@ -29,8 +29,8 @@ DEPTH = None
class Model(ModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.uint8, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
-                InputVar(tf.int32, [None], 'label')]
+        return [InputDesc(tf.uint8, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
+                InputDesc(tf.int32, [None], 'label')]
    def _build_graph(self, inputs):
        image, label = inputs
...
@@ -25,10 +25,9 @@ MODEL_DEPTH = None
class Model(ModelDesc):
-    def _get_input_vars(self):
-        return [InputVar(tf.float32, [None, 224, 224, 3], 'input'),
-                InputVar(tf.int32, [None], 'label')]
+    def _get_inputs(self):
+        return [InputDesc(tf.float32, [None, 224, 224, 3], 'input'),
+                InputDesc(tf.int32, [None], 'label')]
    def _build_graph(self, input_vars):
        image, label = input_vars
...
@@ -17,7 +17,7 @@ IMAGE_SIZE = 224
class Model(tp.ModelDesc):
    def _get_inputs(self):
-        return [tp.InputVar(tf.float32, (IMAGE_SIZE, IMAGE_SIZE, 3), 'image')]
+        return [tp.InputDesc(tf.float32, (IMAGE_SIZE, IMAGE_SIZE, 3), 'image')]
    def _build_graph(self, inputs):
        orig_image = inputs[0]
...
@@ -62,9 +62,9 @@ class SiameseModel(EmbeddingModel):
        return ds
    def _get_inputs(self):
-        return [InputVar(tf.float32, (None, 28, 28), 'input'),
-                InputVar(tf.float32, (None, 28, 28), 'input_y'),
-                InputVar(tf.int32, (None,), 'label')]
+        return [InputDesc(tf.float32, (None, 28, 28), 'input'),
+                InputDesc(tf.float32, (None, 28, 28), 'input_y'),
+                InputDesc(tf.int32, (None,), 'label')]
    def _build_graph(self, inputs):
        # get inputs
@@ -105,9 +105,9 @@ class TripletModel(EmbeddingModel):
        return ds
    def _get_inputs(self):
-        return [InputVar(tf.float32, (None, 28, 28), 'input'),
-                InputVar(tf.float32, (None, 28, 28), 'input_p'),
-                InputVar(tf.float32, (None, 28, 28), 'input_n')]
+        return [InputDesc(tf.float32, (None, 28, 28), 'input'),
+                InputDesc(tf.float32, (None, 28, 28), 'input_p'),
+                InputDesc(tf.float32, (None, 28, 28), 'input_n')]
    def loss(self, a, p, n):
        return symbf.triplet_loss(a, p, n, 5., extra=True, scope="loss")
...
@@ -20,8 +20,8 @@ HALF_DIFF = (IMAGE_SIZE - WARP_TARGET_SIZE) // 2
class Model(ModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE, 2), 'input'),
-                InputVar(tf.int32, (None,), 'label')]
+        return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE, 2), 'input'),
+                InputDesc(tf.int32, (None,), 'label')]
    def _build_graph(self, inputs):
        xys = np.array([(y, x, 1) for y in range(WARP_TARGET_SIZE)
...
@@ -29,8 +29,8 @@ class Model(ModelDesc):
        self.cifar_classnum = cifar_classnum
    def _get_inputs(self):
-        return [InputVar(tf.float32, [None, 30, 30, 3], 'input'),
-                InputVar(tf.int32, [None], 'label')
+        return [InputDesc(tf.float32, [None, 30, 30, 3], 'input'),
+                InputDesc(tf.int32, [None], 'label')
                ]
    def _build_graph(self, inputs):
...
@@ -25,7 +25,7 @@ Usage:
class Model(ModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.float32, (None, 227, 227, 3), 'input')]
+        return [InputDesc(tf.float32, (None, 227, 227, 3), 'input')]
    def _build_graph(self, inputs):
        # img: 227x227x3
...
@@ -25,7 +25,7 @@ Usage:
class Model(ModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.float32, (None, 224, 224, 3), 'input')]
+        return [InputDesc(tf.float32, (None, 224, 224, 3), 'input')]
    def _build_graph(self, inputs):
        image = inputs[0]
...
@@ -26,8 +26,8 @@ class Model(ModelDesc):
    def _get_inputs(self):
        """Define all the input variables (with type, shape, name) that'll be
        fed into the graph to produce a cost. """
-        return [InputVar(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
-                InputVar(tf.int32, (None,), 'label')]
+        return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
+                InputDesc(tf.int32, (None,), 'label')]
    def _build_graph(self, inputs):
        """This function should build the model which takes the input variables
...
@@ -23,8 +23,8 @@ Speed is about 43 it/s on TitanX.
class Model(ModelDesc):
    def _get_inputs(self):
-        return [InputVar(tf.float32, [None, 40, 40, 3], 'input'),
-                InputVar(tf.int32, [None], 'label')]
+        return [InputDesc(tf.float32, [None, 40, 40, 3], 'input'),
+                InputDesc(tf.int32, [None], 'label')]
    def _build_graph(self, inputs):
        image, label = inputs
...
@@ -25,7 +25,7 @@ with tf.Graph().as_default() as G:
        MODEL = imp.load_source('config_script', args.config).Model
        M = MODEL()
        with TowerContext('', is_training=False):
-            M.build_graph(M.get_input_vars())
+            M.build_graph(M.get_reused_placehdrs())
    else:
        M = ModelFromMetaGraph(args.meta)
...
@@ -158,7 +158,7 @@ class FeedfreeInferenceRunner(Triggerable):
        Args:
            input (FeedfreeInput): the input to use. Must have ``size()``.
            infs (list): list of :class:`Inferencer` to run.
-            input_names (list): must be a subset of the names of InputVar.
+            input_names (list): must be a subset of the names in InputDesc.
            prefix(str): an prefix used to build the tower. Must be set
                differently if more than one :class:`FeedfreeInferenceRunner` are used.
        """
@@ -211,7 +211,7 @@ class FeedfreeInferenceRunner(Triggerable):
                    break
            else:
                raise ValueError(
-                    "{} doesn't appear in the InputVar of the model!".format(n))
+                    "{} doesn't appear in the InputDesc of the model!".format(n))
        self._input_tensors = model_placehdrs
        assert len(self._input_tensors) == len(model_placehdrs), \
...
@@ -41,6 +41,7 @@ class InputDesc(object):
        return pickle.loads(buf)
+# TODO print warning?
InputVar = InputDesc
@@ -62,6 +63,7 @@ class ModelDesc(object):
        return ret
    def get_input_vars(self):
+        # this wasn't a public API anyway
        logger.warn("[Deprecated] get_input_vars() was renamed to get_reused_placehdrs()!")
        return self.get_reused_placehdrs()
@@ -95,6 +97,7 @@ class ModelDesc(object):
        """
        :returns: a list of InputDesc
        """
+        # TODO deprecate @ Mar 11
        logger.warn("[Deprecated] _get_input_vars() is renamed to _get_inputs()")
        return self._get_input_vars()
...
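The backward-compatibility pattern this last file relies on (a module-level alias plus deprecated methods that warn and forward to the new names) can be sketched roughly as below. This is a simplified illustration using the standard `warnings` module, not the actual tensorpack implementation, which uses its own `logger`:

```python
import warnings

class InputDesc(object):
    """Metadata of one model input: dtype, shape and name."""
    def __init__(self, type, shape, name):
        self.type, self.shape, self.name = type, shape, name

# keep the old name working as a plain alias
InputVar = InputDesc

class ModelDesc(object):
    def get_reused_placehdrs(self):
        # would build (or reuse) one tf.placeholder per InputDesc
        raise NotImplementedError

    def get_input_vars(self):
        # deprecated spelling: warn, then forward to the new method
        warnings.warn("get_input_vars() was renamed to get_reused_placehdrs()",
                      DeprecationWarning)
        return self.get_reused_placehdrs()

    def _get_inputs(self):
        # new hook; by default fall back to the old one so legacy models keep working
        warnings.warn("_get_input_vars() is renamed to _get_inputs()",
                      DeprecationWarning)
        return self._get_input_vars()
```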