Commit 96255c9a authored by Yuxin Wu's avatar Yuxin Wu

model multiple inputs as a list

parent b506eb0a
......@@ -92,18 +92,18 @@ def get_config():
dataset_train = dataset.Cifar10('train')
augmentors = [
RandomCrop((24, 24)),
Flip(horiz=True),
BrightnessAdd(63),
Contrast((0.2,1.8)),
MeanVarianceNormalize(all_channel=True)
imgaug.RandomCrop((24, 24)),
imgaug.Flip(horiz=True),
imgaug.BrightnessAdd(63),
imgaug.Contrast((0.2,1.8)),
imgaug.MeanVarianceNormalize(all_channel=True)
]
dataset_train = AugmentImageComponent(dataset_train, augmentors)
dataset_train = BatchData(dataset_train, 128)
augmentors = [
CenterCrop((24, 24)),
MeanVarianceNormalize(all_channel=True)
imgaug.CenterCrop((24, 24)),
imgaug.MeanVarianceNormalize(all_channel=True)
]
dataset_test = dataset.Cifar10('test')
dataset_test = AugmentImageComponent(dataset_test, augmentors)
......
......@@ -7,6 +7,7 @@ import cPickle
import numpy
from six.moves import urllib
import tarfile
import logging
from ...utils import logger
from ..base import DataFlow
......@@ -24,10 +25,11 @@ def maybe_download_and_extract(dest_directory):
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
if os.path.isdir(os.path.join(dest_directory, 'cifar-10-batches-py')):
logger.info("Found cifar10 data in {}.".format(dest_directory))
return
else:
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
sys.stdout.write('\r>> Downloading %s %.1f%%' % (filepath,
float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, reporthook=_progress)
......
......@@ -24,7 +24,7 @@ def maybe_download(filename, work_directory):
os.mkdir(work_directory)
filepath = os.path.join(work_directory, filename)
if not os.path.exists(filepath):
logger.info("Downloading mnist data...")
logger.info("Downloading mnist data to {}...".format(filepath))
filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
statinfo = os.stat(filepath)
logger.info('Successfully downloaded to ' + filename)
......
......@@ -31,7 +31,7 @@ def Conv2D(x, out_channel, kernel_shape,
stride = shape4d(stride)
if W_init is None:
W_init = tf.truncated_normal_initializer(stddev=4e-2)
W_init = tf.truncated_normal_initializer(stddev=1e-2)
if b_init is None:
b_init = tf.constant_initializer()
......
......@@ -17,7 +17,7 @@ def FullyConnected(x, out_dim, W_init=None, b_init=None, nl=tf.nn.relu):
in_dim = x.get_shape().as_list()[1]
if W_init is None:
W_init = tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(in_dim)))
W_init = tf.truncated_normal_initializer(stddev=1 / math.sqrt(float(in_dim)))
if b_init is None:
b_init = tf.constant_initializer(0.0)
......
......@@ -39,21 +39,24 @@ def sample(img, coords):
return sampled
@layer_register()
def ImageSample(template, mapping):
def ImageSample(inputs):
"""
Sample the template image, using the given coordinate, by bilinear interpolation.
inputs: list of [template, mapping]
template: bxhxwxc
mapping: bxh2xw2x2 (y, x) real-value coordinates
Return: bxh2xw2xc
"""
template, mapping = inputs
assert template.get_shape().ndims == 4 and mapping.get_shape().ndims == 4
mapping = tf.maximum(mapping, 0.0)
tf.check_numerics(mapping, "mapping")
lcoor = tf.cast(mapping, tf.int32) # floor
ucoor = lcoor + 1
# has to cast to int32 and then cast back
# XXX tf.floor have gradient 1 w.r.t input, bug or feature?
# tf.floor have gradient 1 w.r.t input
# TODO bug fixed in #951
diff = mapping - tf.cast(lcoor, tf.float32)
neg_diff = 1.0 - diff #bxh2xw2x2
......@@ -128,7 +131,7 @@ if __name__ == '__main__':
mapping[0,y,x,:] = np.array([y-diff+0.4, x-diff+0.5])
mapv = tf.Variable(mapping)
output = ImageSample('sample', imv, mapv)
output = ImageSample('sample', [imv, mapv])
sess = tf.Session()
sess.run(tf.initialize_all_variables())
......
......@@ -23,9 +23,9 @@ def describe_model():
def get_shape_str(tensors):
""" return the shape string for a tensor or a list of tensors"""
if isinstance(tensors, list):
if isinstance(tensors, (list, tuple)):
shape_str = ",".join(
map(str(x.get_shape().as_list()), tensors))
map(lambda x: str(x.get_shape().as_list()), tensors))
else:
shape_str = str(tensors.get_shape().as_list())
return shape_str
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment