Commit ae9627cf authored by Yuxin Wu's avatar Yuxin Wu

sed -i 's/InputDesc/tf.placeholder/g;s/_get_inputs/inputs/g' (#318)

parent 39fa4656
...@@ -69,12 +69,12 @@ class MySimulatorWorker(SimulatorProcess): ...@@ -69,12 +69,12 @@ class MySimulatorWorker(SimulatorProcess):
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
assert NUM_ACTIONS is not None assert NUM_ACTIONS is not None
return [InputDesc(tf.uint8, (None,) + IMAGE_SHAPE3, 'state'), return [tf.placeholder(tf.uint8, (None,) + IMAGE_SHAPE3, 'state'),
InputDesc(tf.int64, (None,), 'action'), tf.placeholder(tf.int64, (None,), 'action'),
InputDesc(tf.float32, (None,), 'futurereward'), tf.placeholder(tf.float32, (None,), 'futurereward'),
InputDesc(tf.float32, (None,), 'action_prob'), tf.placeholder(tf.float32, (None,), 'action_prob'),
] ]
def _get_NN_prediction(self, image): def _get_NN_prediction(self, image):
......
...@@ -25,12 +25,12 @@ FEATUREDIM = 39 # MFCC feature dimension ...@@ -25,12 +25,12 @@ FEATUREDIM = 39 # MFCC feature dimension
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, [None, None, FEATUREDIM], 'feat'), # bxmaxseqx39 return [tf.placeholder(tf.float32, [None, None, FEATUREDIM], 'feat'), # bxmaxseqx39
InputDesc(tf.int64, [None, None], 'labelidx'), # label is b x maxlen, sparse tf.placeholder(tf.int64, [None, None], 'labelidx'), # label is b x maxlen, sparse
InputDesc(tf.int32, [None], 'labelvalue'), tf.placeholder(tf.int32, [None], 'labelvalue'),
InputDesc(tf.int64, [None], 'labelshape'), tf.placeholder(tf.int64, [None], 'labelshape'),
InputDesc(tf.int32, [None], 'seqlen'), # b tf.placeholder(tf.int32, [None], 'seqlen'), # b
] ]
def _build_graph(self, inputs): def _build_graph(self, inputs):
......
...@@ -70,9 +70,9 @@ class CharRNNData(RNGDataFlow): ...@@ -70,9 +70,9 @@ class CharRNNData(RNGDataFlow):
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.int32, (None, param.seq_len), 'input'), return [tf.placeholder(tf.int32, (None, param.seq_len), 'input'),
InputDesc(tf.int32, (None, param.seq_len), 'nextinput')] tf.placeholder(tf.int32, (None, param.seq_len), 'nextinput')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
input, nextinput = inputs input, nextinput = inputs
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
import abc import abc
import tensorflow as tf import tensorflow as tf
import tensorpack import tensorpack
from tensorpack import ModelDesc, InputDesc from tensorpack import ModelDesc
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.tfutils import ( from tensorpack.tfutils import (
varreplace, summary, get_current_tower_context, optimizer, gradproc) varreplace, summary, get_current_tower_context, optimizer, gradproc)
...@@ -24,15 +24,15 @@ class Model(ModelDesc): ...@@ -24,15 +24,15 @@ class Model(ModelDesc):
self.num_actions = num_actions self.num_actions = num_actions
self.gamma = gamma self.gamma = gamma
def _get_inputs(self): def inputs(self):
# Use a combined state for efficiency. # Use a combined state for efficiency.
# The first h channels are the current state, and the last h channels are the next state. # The first h channels are the current state, and the last h channels are the next state.
return [InputDesc(tf.uint8, return [tf.placeholder(tf.uint8,
(None,) + self.image_shape + (self.channel + 1,), (None,) + self.image_shape + (self.channel + 1,),
'comb_state'), 'comb_state'),
InputDesc(tf.int64, (None,), 'action'), tf.placeholder(tf.int64, (None,), 'action'),
InputDesc(tf.float32, (None,), 'reward'), tf.placeholder(tf.float32, (None,), 'reward'),
InputDesc(tf.bool, (None,), 'isOver')] tf.placeholder(tf.bool, (None,), 'isOver')]
@abc.abstractmethod @abc.abstractmethod
def _get_DQN_prediction(self, image): def _get_DQN_prediction(self, image):
......
...@@ -77,9 +77,9 @@ BATCH_SIZE = None ...@@ -77,9 +77,9 @@ BATCH_SIZE = None
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, [None, 224, 224, 3], 'input'), return [tf.placeholder(tf.float32, [None, 224, 224, 3], 'input'),
InputDesc(tf.int32, [None], 'label')] tf.placeholder(tf.int32, [None], 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
image, label = inputs image, label = inputs
......
...@@ -32,9 +32,9 @@ BITG = 32 ...@@ -32,9 +32,9 @@ BITG = 32
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, [None, 224, 224, 3], 'input'), return [tf.placeholder(tf.float32, [None, 224, 224, 3], 'input'),
InputDesc(tf.int32, [None], 'label')] tf.placeholder(tf.int32, [None], 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
image, label = inputs image, label = inputs
......
...@@ -43,9 +43,9 @@ BITG = 4 ...@@ -43,9 +43,9 @@ BITG = 4
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, [None, 40, 40, 3], 'input'), return [tf.placeholder(tf.float32, [None, 40, 40, 3], 'input'),
InputDesc(tf.int32, [None], 'label')] tf.placeholder(tf.int32, [None], 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
image, label = inputs image, label = inputs
......
...@@ -95,11 +95,11 @@ class OnlineTensorboardExport(Callback): ...@@ -95,11 +95,11 @@ class OnlineTensorboardExport(Callback):
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, (BATCH, ), 'theta'), return [tf.placeholder(tf.float32, (BATCH, ), 'theta'),
InputDesc(tf.float32, (BATCH, SHAPE, SHAPE), 'image'), tf.placeholder(tf.float32, (BATCH, SHAPE, SHAPE), 'image'),
InputDesc(tf.float32, (BATCH, SHAPE, SHAPE), 'gt_image'), tf.placeholder(tf.float32, (BATCH, SHAPE, SHAPE), 'gt_image'),
InputDesc(tf.float32, (BATCH, 9, 9), 'gt_filter')] tf.placeholder(tf.float32, (BATCH, 9, 9), 'gt_filter')]
def _parameter_net(self, theta, kernel_shape=9): def _parameter_net(self, theta, kernel_shape=9):
"""Estimate filters for convolution layers """Estimate filters for convolution layers
......
...@@ -57,16 +57,16 @@ def get_model_output_names(): ...@@ -57,16 +57,16 @@ def get_model_output_names():
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
ret = [ ret = [
InputDesc(tf.float32, (None, None, 3), 'image'), tf.placeholder(tf.float32, (None, None, 3), 'image'),
InputDesc(tf.int32, (None, None, config.NUM_ANCHOR), 'anchor_labels'), tf.placeholder(tf.int32, (None, None, config.NUM_ANCHOR), 'anchor_labels'),
InputDesc(tf.float32, (None, None, config.NUM_ANCHOR, 4), 'anchor_boxes'), tf.placeholder(tf.float32, (None, None, config.NUM_ANCHOR, 4), 'anchor_boxes'),
InputDesc(tf.float32, (None, 4), 'gt_boxes'), tf.placeholder(tf.float32, (None, 4), 'gt_boxes'),
InputDesc(tf.int64, (None,), 'gt_labels')] # all > 0 tf.placeholder(tf.int64, (None,), 'gt_labels')] # all > 0
if config.MODE_MASK: if config.MODE_MASK:
ret.append( ret.append(
InputDesc(tf.uint8, (None, None, None), 'gt_masks') tf.placeholder(tf.uint8, (None, None, None), 'gt_masks')
) # NR_GT x height x width ) # NR_GT x height x width
return ret return ret
......
...@@ -26,8 +26,8 @@ GAMMA = 0.5 ...@@ -26,8 +26,8 @@ GAMMA = 0.5
class Model(GANModelDesc): class Model(GANModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, (None, args.final_size, args.final_size, 3), 'input')] return [tf.placeholder(tf.float32, (None, args.final_size, args.final_size, 3), 'input')]
@auto_reuse_variable_scope @auto_reuse_variable_scope
def decoder(self, z): def decoder(self, z):
......
...@@ -40,9 +40,9 @@ def batch_flatten(x): ...@@ -40,9 +40,9 @@ def batch_flatten(x):
class Model(GANModelDesc): class Model(GANModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, (None, 28, 28), 'input'), return [tf.placeholder(tf.float32, (None, 28, 28), 'input'),
InputDesc(tf.int32, (None,), 'label')] tf.placeholder(tf.int32, (None,), 'label')]
def generator(self, z, y): def generator(self, z, y):
l = FullyConnected('fc0', tf.concat([z, y], 1), 1024, activation=BNReLU) l = FullyConnected('fc0', tf.concat([z, y], 1), 1024, activation=BNReLU)
......
...@@ -41,9 +41,9 @@ def INLReLU(x, name=None): ...@@ -41,9 +41,9 @@ def INLReLU(x, name=None):
class Model(GANModelDesc): class Model(GANModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'inputA'), return [tf.placeholder(tf.float32, (None, SHAPE, SHAPE, 3), 'inputA'),
InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'inputB')] tf.placeholder(tf.float32, (None, SHAPE, SHAPE, 3), 'inputB')]
@staticmethod @staticmethod
def build_res_block(x, name, chan, first=False): def build_res_block(x, name, chan, first=False):
......
...@@ -40,8 +40,8 @@ class Model(GANModelDesc): ...@@ -40,8 +40,8 @@ class Model(GANModelDesc):
self.batch = batch self.batch = batch
self.zdim = z_dim self.zdim = z_dim
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, (None, self.shape, self.shape, 3), 'input')] return [tf.placeholder(tf.float32, (None, self.shape, self.shape, 3), 'input')]
def generator(self, z): def generator(self, z):
""" return an image generated from z""" """ return an image generated from z"""
......
...@@ -34,9 +34,9 @@ def BNLReLU(x, name=None): ...@@ -34,9 +34,9 @@ def BNLReLU(x, name=None):
class Model(GANModelDesc): class Model(GANModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'inputA'), return [tf.placeholder(tf.float32, (None, SHAPE, SHAPE, 3), 'inputA'),
InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'inputB')] tf.placeholder(tf.float32, (None, SHAPE, SHAPE, 3), 'inputB')]
@auto_reuse_variable_scope @auto_reuse_variable_scope
def generator(self, img): def generator(self, img):
......
...@@ -63,10 +63,10 @@ def visualize_tensors(name, imgs, scale_func=lambda x: (x + 1.) * 128., max_outp ...@@ -63,10 +63,10 @@ def visualize_tensors(name, imgs, scale_func=lambda x: (x + 1.) * 128., max_outp
class Model(GANModelDesc): class Model(GANModelDesc):
def _get_inputs(self): def inputs(self):
SHAPE = 256 SHAPE = 256
return [InputDesc(tf.float32, (None, SHAPE, SHAPE, IN_CH), 'input'), return [tf.placeholder(tf.float32, (None, SHAPE, SHAPE, IN_CH), 'input'),
InputDesc(tf.float32, (None, SHAPE, SHAPE, OUT_CH), 'output')] tf.placeholder(tf.float32, (None, SHAPE, SHAPE, OUT_CH), 'output')]
def generator(self, imgs): def generator(self, imgs):
# imgs: input: 256x256xch # imgs: input: 256x256xch
......
...@@ -105,8 +105,8 @@ def sample_prior(batch_size): ...@@ -105,8 +105,8 @@ def sample_prior(batch_size):
class Model(GANModelDesc): class Model(GANModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, (None, 28, 28), 'input')] return [tf.placeholder(tf.float32, (None, 28, 28), 'input')]
def generator(self, z): def generator(self, z):
l = FullyConnected('fc0', z, 1024, activation=BNReLU) l = FullyConnected('fc0', z, 1024, activation=BNReLU)
......
...@@ -44,9 +44,9 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss ...@@ -44,9 +44,9 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, [None, None, None, 3], 'image'), return [tf.placeholder(tf.float32, [None, None, None, 3], 'image'),
InputDesc(tf.int32, [None, None, None], 'edgemap')] tf.placeholder(tf.int32, [None, None, None], 'edgemap')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
image, edgemap = inputs image, edgemap = inputs
......
...@@ -9,7 +9,7 @@ import multiprocessing ...@@ -9,7 +9,7 @@ import multiprocessing
import tensorflow as tf import tensorflow as tf
from abc import abstractmethod from abc import abstractmethod
from tensorpack import imgaug, dataset, ModelDesc, InputDesc from tensorpack import imgaug, dataset, ModelDesc
from tensorpack.dataflow import ( from tensorpack.dataflow import (
AugmentImageComponent, PrefetchDataZMQ, AugmentImageComponent, PrefetchDataZMQ,
BatchData, MultiThreadMapData) BatchData, MultiThreadMapData)
...@@ -148,9 +148,9 @@ class ImageNetModel(ModelDesc): ...@@ -148,9 +148,9 @@ class ImageNetModel(ModelDesc):
def __init__(self, data_format='NCHW'): def __init__(self, data_format='NCHW'):
self.data_format = data_format self.data_format = data_format
def _get_inputs(self): def inputs(self):
return [InputDesc(self.image_dtype, [None, self.image_shape, self.image_shape, 3], 'input'), return [tf.placeholder(self.image_dtype, [None, self.image_shape, self.image_shape, 3], 'input'),
InputDesc(tf.int32, [None], 'label')] tf.placeholder(tf.int32, [None], 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
image, label = inputs image, label = inputs
......
...@@ -24,9 +24,9 @@ INPUT_SHAPE = 224 ...@@ -24,9 +24,9 @@ INPUT_SHAPE = 224
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'), return [tf.placeholder(tf.float32, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
InputDesc(tf.int32, [None], 'label')] tf.placeholder(tf.int32, [None], 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
image, label = inputs image, label = inputs
......
...@@ -46,9 +46,9 @@ def get_PennTreeBank(data_dir=None): ...@@ -46,9 +46,9 @@ def get_PennTreeBank(data_dir=None):
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.int32, (None, SEQ_LEN), 'input'), return [tf.placeholder(tf.int32, (None, SEQ_LEN), 'input'),
InputDesc(tf.int32, (None, SEQ_LEN), 'nextinput')] tf.placeholder(tf.int32, (None, SEQ_LEN), 'nextinput')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
is_training = get_current_tower_context().is_training is_training = get_current_tower_context().is_training
......
...@@ -39,9 +39,9 @@ def preactivation_block(input, num_filters, stride=1): ...@@ -39,9 +39,9 @@ def preactivation_block(input, num_filters, stride=1):
class ResNet_Cifar(ModelDesc): class ResNet_Cifar(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, [None, 32, 32, 3], 'input'), return [tf.placeholder(tf.float32, [None, 32, 32, 3], 'input'),
InputDesc(tf.float32, [None, CLASS_NUM], 'label')] tf.placeholder(tf.float32, [None, CLASS_NUM], 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
assert tf.test.is_gpu_available() assert tf.test.is_gpu_available()
......
...@@ -40,9 +40,9 @@ class Model(ModelDesc): ...@@ -40,9 +40,9 @@ class Model(ModelDesc):
super(Model, self).__init__() super(Model, self).__init__()
self.n = n self.n = n
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, [None, 32, 32, 3], 'input'), return [tf.placeholder(tf.float32, [None, 32, 32, 3], 'input'),
InputDesc(tf.int32, [None], 'label')] tf.placeholder(tf.int32, [None], 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
image, label = inputs image, label = inputs
......
...@@ -28,9 +28,9 @@ CFG = { ...@@ -28,9 +28,9 @@ CFG = {
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, [None, 224, 224, 3], 'input'), return [tf.placeholder(tf.float32, [None, 224, 224, 3], 'input'),
InputDesc(tf.int32, [None], 'label')] tf.placeholder(tf.int32, [None], 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
image, label = inputs image, label = inputs
......
...@@ -31,9 +31,9 @@ DEPTH = None ...@@ -31,9 +31,9 @@ DEPTH = None
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.uint8, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'), return [tf.placeholder(tf.uint8, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
InputDesc(tf.int32, [None], 'label')] tf.placeholder(tf.int32, [None], 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
image, label = inputs image, label = inputs
......
...@@ -54,8 +54,8 @@ def saliency_map(output, input, name="saliency_map"): ...@@ -54,8 +54,8 @@ def saliency_map(output, input, name="saliency_map"):
class Model(tp.ModelDesc): class Model(tp.ModelDesc):
def _get_inputs(self): def inputs(self):
return [tp.InputDesc(tf.float32, (IMAGE_SIZE, IMAGE_SIZE, 3), 'image')] return [tf.placeholder(tf.float32, (IMAGE_SIZE, IMAGE_SIZE, 3), 'image')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
orig_image = inputs[0] orig_image = inputs[0]
......
...@@ -236,10 +236,10 @@ class SiameseModel(EmbeddingModel): ...@@ -236,10 +236,10 @@ class SiameseModel(EmbeddingModel):
ds = BatchData(ds, 128 // 2) ds = BatchData(ds, 128 // 2)
return ds return ds
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, (None, 28, 28), 'input'), return [tf.placeholder(tf.float32, (None, 28, 28), 'input'),
InputDesc(tf.float32, (None, 28, 28), 'input_y'), tf.placeholder(tf.float32, (None, 28, 28), 'input_y'),
InputDesc(tf.int32, (None,), 'label')] tf.placeholder(tf.int32, (None,), 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
# get inputs # get inputs
...@@ -279,10 +279,10 @@ class TripletModel(EmbeddingModel): ...@@ -279,10 +279,10 @@ class TripletModel(EmbeddingModel):
ds = BatchData(ds, 128 // 3) ds = BatchData(ds, 128 // 3)
return ds return ds
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, (None, 28, 28), 'input'), return [tf.placeholder(tf.float32, (None, 28, 28), 'input'),
InputDesc(tf.float32, (None, 28, 28), 'input_p'), tf.placeholder(tf.float32, (None, 28, 28), 'input_p'),
InputDesc(tf.float32, (None, 28, 28), 'input_n')] tf.placeholder(tf.float32, (None, 28, 28), 'input_n')]
def loss(self, a, p, n): def loss(self, a, p, n):
return triplet_loss(a, p, n, 5., extra=True, scope="loss") return triplet_loss(a, p, n, 5., extra=True, scope="loss")
...@@ -312,9 +312,9 @@ class CenterModel(EmbeddingModel): ...@@ -312,9 +312,9 @@ class CenterModel(EmbeddingModel):
ds = BatchData(ds, 128) ds = BatchData(ds, 128)
return ds return ds
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, (None, 28, 28), 'input'), return [tf.placeholder(tf.float32, (None, 28, 28), 'input'),
InputDesc(tf.int32, (None,), 'label')] tf.placeholder(tf.int32, (None,), 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
# get inputs # get inputs
......
...@@ -20,9 +20,9 @@ HALF_DIFF = (IMAGE_SIZE - WARP_TARGET_SIZE) // 2 ...@@ -20,9 +20,9 @@ HALF_DIFF = (IMAGE_SIZE - WARP_TARGET_SIZE) // 2
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE, 2), 'input'), return [tf.placeholder(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE, 2), 'input'),
InputDesc(tf.int32, (None,), 'label')] tf.placeholder(tf.int32, (None,), 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
xys = np.array([(y, x, 1) for y in range(WARP_TARGET_SIZE) xys = np.array([(y, x, 1) for y in range(WARP_TARGET_SIZE)
......
...@@ -48,9 +48,9 @@ class Model(GANModelDesc): ...@@ -48,9 +48,9 @@ class Model(GANModelDesc):
self.height = height self.height = height
self.width = width self.width = width
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, (None, self.height * 1, self.width * 1, CHANNELS), 'Ilr'), return [tf.placeholder(tf.float32, (None, self.height * 1, self.width * 1, CHANNELS), 'Ilr'),
InputDesc(tf.float32, (None, self.height * 4, self.width * 4, CHANNELS), 'Ihr')] tf.placeholder(tf.float32, (None, self.height * 4, self.width * 4, CHANNELS), 'Ihr')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
ctx = get_current_tower_context() ctx = get_current_tower_context()
......
...@@ -23,13 +23,13 @@ IMAGE_SIZE = 28 ...@@ -23,13 +23,13 @@ IMAGE_SIZE = 28
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
""" """
Define all the inputs (with type, shape, name) that Define all the inputs (with type, shape, name) that
the graph will need. the graph will need.
""" """
return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'), return [tf.placeholder(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
InputDesc(tf.int32, (None,), 'label')] tf.placeholder(tf.int32, (None,), 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
"""This function should build the model which takes the input variables """This function should build the model which takes the input variables
......
...@@ -22,9 +22,9 @@ IMAGE_SIZE = 28 ...@@ -22,9 +22,9 @@ IMAGE_SIZE = 28
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'), return [tf.placeholder(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
InputDesc(tf.int32, (None,), 'label')] tf.placeholder(tf.int32, (None,), 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
image, label = inputs image, label = inputs
......
...@@ -68,9 +68,9 @@ def visualize_conv_activations(activation, name): ...@@ -68,9 +68,9 @@ def visualize_conv_activations(activation, name):
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'), return [tf.placeholder(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
InputDesc(tf.int32, (None,), 'label')] tf.placeholder(tf.int32, (None,), 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
......
...@@ -22,9 +22,9 @@ Speed is about 43 it/s on TitanX. ...@@ -22,9 +22,9 @@ Speed is about 43 it/s on TitanX.
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, [None, 40, 40, 3], 'input'), return [tf.placeholder(tf.float32, [None, 40, 40, 3], 'input'),
InputDesc(tf.int32, [None], 'label')] tf.placeholder(tf.int32, [None], 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
image, label = inputs image, label = inputs
......
...@@ -19,9 +19,9 @@ CHANNELS = 3 ...@@ -19,9 +19,9 @@ CHANNELS = 3
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, (None, SHAPE, SHAPE, CHANNELS), 'input'), return [tf.placeholder(tf.float32, (None, SHAPE, SHAPE, CHANNELS), 'input'),
InputDesc(tf.int32, (None,), 'label')] tf.placeholder(tf.int32, (None,), 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
image, label = inputs image, label = inputs
......
...@@ -40,9 +40,9 @@ def get_keras_model(): ...@@ -40,9 +40,9 @@ def get_keras_model():
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def inputs(self):
return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'), return [tf.placeholder(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
InputDesc(tf.int32, (None,), 'label')] tf.placeholder(tf.int32, (None,), 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
image, label = inputs image, label = inputs
......
...@@ -102,8 +102,10 @@ class ModelDescBase(object): ...@@ -102,8 +102,10 @@ class ModelDescBase(object):
try: try:
return self._get_inputs() return self._get_inputs()
except NotImplementedError: except NotImplementedError:
with tf.Graph().as_default(): # create these placeholder in a temporary graph with tf.Graph().as_default() as G: # create these placeholder in a temporary graph
inputs = self.inputs() inputs = self.inputs()
for p in inputs:
assert p.graph == G, "Placeholders returned by inputs() sholud be created inside inputs()!"
return [InputDesc.from_placeholder(p) for p in inputs] return [InputDesc.from_placeholder(p) for p in inputs]
def _get_inputs(self): def _get_inputs(self):
...@@ -117,7 +119,11 @@ class ModelDescBase(object): ...@@ -117,7 +119,11 @@ class ModelDescBase(object):
""" """
__Create__ and returns a list of placeholders. __Create__ and returns a list of placeholders.
To be implemented by subclass. To be implemented by subclass.
The placeholders __have to__ be created inside this function.
The placeholders __have to__ be created inside this method.
Don't return placeholders created in other methods.
You should not call this method by yourself.
Returns: Returns:
a list of `tf.placeholder`, to be converted to :class:`InputDesc`. a list of `tf.placeholder`, to be converted to :class:`InputDesc`.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment