Commit 96255c9a authored by Yuxin Wu

model multiple inputs as a list

parent b506eb0a
...@@ -92,18 +92,18 @@ def get_config(): ...@@ -92,18 +92,18 @@ def get_config():
dataset_train = dataset.Cifar10('train') dataset_train = dataset.Cifar10('train')
augmentors = [ augmentors = [
RandomCrop((24, 24)), imgaug.RandomCrop((24, 24)),
Flip(horiz=True), imgaug.Flip(horiz=True),
BrightnessAdd(63), imgaug.BrightnessAdd(63),
Contrast((0.2,1.8)), imgaug.Contrast((0.2,1.8)),
MeanVarianceNormalize(all_channel=True) imgaug.MeanVarianceNormalize(all_channel=True)
] ]
dataset_train = AugmentImageComponent(dataset_train, augmentors) dataset_train = AugmentImageComponent(dataset_train, augmentors)
dataset_train = BatchData(dataset_train, 128) dataset_train = BatchData(dataset_train, 128)
augmentors = [ augmentors = [
CenterCrop((24, 24)), imgaug.CenterCrop((24, 24)),
MeanVarianceNormalize(all_channel=True) imgaug.MeanVarianceNormalize(all_channel=True)
] ]
dataset_test = dataset.Cifar10('test') dataset_test = dataset.Cifar10('test')
dataset_test = AugmentImageComponent(dataset_test, augmentors) dataset_test = AugmentImageComponent(dataset_test, augmentors)
......
...@@ -7,6 +7,7 @@ import cPickle ...@@ -7,6 +7,7 @@ import cPickle
import numpy import numpy
from six.moves import urllib from six.moves import urllib
import tarfile import tarfile
import logging
from ...utils import logger from ...utils import logger
from ..base import DataFlow from ..base import DataFlow
...@@ -24,10 +25,11 @@ def maybe_download_and_extract(dest_directory): ...@@ -24,10 +25,11 @@ def maybe_download_and_extract(dest_directory):
filename = DATA_URL.split('/')[-1] filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename) filepath = os.path.join(dest_directory, filename)
if os.path.isdir(os.path.join(dest_directory, 'cifar-10-batches-py')): if os.path.isdir(os.path.join(dest_directory, 'cifar-10-batches-py')):
logger.info("Found cifar10 data in {}.".format(dest_directory))
return return
else: else:
def _progress(count, block_size, total_size): def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, sys.stdout.write('\r>> Downloading %s %.1f%%' % (filepath,
float(count * block_size) / float(total_size) * 100.0)) float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush() sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, reporthook=_progress) filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, reporthook=_progress)
......
...@@ -24,7 +24,7 @@ def maybe_download(filename, work_directory): ...@@ -24,7 +24,7 @@ def maybe_download(filename, work_directory):
os.mkdir(work_directory) os.mkdir(work_directory)
filepath = os.path.join(work_directory, filename) filepath = os.path.join(work_directory, filename)
if not os.path.exists(filepath): if not os.path.exists(filepath):
logger.info("Downloading mnist data...") logger.info("Downloading mnist data to {}...".format(filepath))
filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath) filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
statinfo = os.stat(filepath) statinfo = os.stat(filepath)
logger.info('Successfully downloaded to ' + filename) logger.info('Successfully downloaded to ' + filename)
......
...@@ -31,7 +31,7 @@ def Conv2D(x, out_channel, kernel_shape, ...@@ -31,7 +31,7 @@ def Conv2D(x, out_channel, kernel_shape,
stride = shape4d(stride) stride = shape4d(stride)
if W_init is None: if W_init is None:
W_init = tf.truncated_normal_initializer(stddev=4e-2) W_init = tf.truncated_normal_initializer(stddev=1e-2)
if b_init is None: if b_init is None:
b_init = tf.constant_initializer() b_init = tf.constant_initializer()
......
...@@ -17,7 +17,7 @@ def FullyConnected(x, out_dim, W_init=None, b_init=None, nl=tf.nn.relu): ...@@ -17,7 +17,7 @@ def FullyConnected(x, out_dim, W_init=None, b_init=None, nl=tf.nn.relu):
in_dim = x.get_shape().as_list()[1] in_dim = x.get_shape().as_list()[1]
if W_init is None: if W_init is None:
W_init = tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(in_dim))) W_init = tf.truncated_normal_initializer(stddev=1 / math.sqrt(float(in_dim)))
if b_init is None: if b_init is None:
b_init = tf.constant_initializer(0.0) b_init = tf.constant_initializer(0.0)
......
...@@ -39,21 +39,24 @@ def sample(img, coords): ...@@ -39,21 +39,24 @@ def sample(img, coords):
return sampled return sampled
@layer_register() @layer_register()
def ImageSample(template, mapping): def ImageSample(inputs):
""" """
Sample the template image, using the given coordinate, by bilinear interpolation. Sample the template image, using the given coordinate, by bilinear interpolation.
inputs: list of [template, mapping]
template: bxhxwxc template: bxhxwxc
mapping: bxh2xw2x2 (y, x) real-value coordinates mapping: bxh2xw2x2 (y, x) real-value coordinates
Return: bxh2xw2xc Return: bxh2xw2xc
""" """
template, mapping = inputs
assert template.get_shape().ndims == 4 and mapping.get_shape().ndims == 4 assert template.get_shape().ndims == 4 and mapping.get_shape().ndims == 4
mapping = tf.maximum(mapping, 0.0) mapping = tf.maximum(mapping, 0.0)
tf.check_numerics(mapping, "mapping")
lcoor = tf.cast(mapping, tf.int32) # floor lcoor = tf.cast(mapping, tf.int32) # floor
ucoor = lcoor + 1 ucoor = lcoor + 1
# has to cast to int32 and then cast back # has to cast to int32 and then cast back
# XXX tf.floor have gradient 1 w.r.t input, bug or feature? # tf.floor have gradient 1 w.r.t input
# TODO bug fixed in #951
diff = mapping - tf.cast(lcoor, tf.float32) diff = mapping - tf.cast(lcoor, tf.float32)
neg_diff = 1.0 - diff #bxh2xw2x2 neg_diff = 1.0 - diff #bxh2xw2x2
...@@ -128,7 +131,7 @@ if __name__ == '__main__': ...@@ -128,7 +131,7 @@ if __name__ == '__main__':
mapping[0,y,x,:] = np.array([y-diff+0.4, x-diff+0.5]) mapping[0,y,x,:] = np.array([y-diff+0.4, x-diff+0.5])
mapv = tf.Variable(mapping) mapv = tf.Variable(mapping)
output = ImageSample('sample', imv, mapv) output = ImageSample('sample', [imv, mapv])
sess = tf.Session() sess = tf.Session()
sess.run(tf.initialize_all_variables()) sess.run(tf.initialize_all_variables())
......
...@@ -23,9 +23,9 @@ def describe_model(): ...@@ -23,9 +23,9 @@ def describe_model():
def get_shape_str(tensors): def get_shape_str(tensors):
""" return the shape string for a tensor or a list of tensors""" """ return the shape string for a tensor or a list of tensors"""
if isinstance(tensors, list): if isinstance(tensors, (list, tuple)):
shape_str = ",".join( shape_str = ",".join(
map(str(x.get_shape().as_list()), tensors)) map(lambda x: str(x.get_shape().as_list()), tensors))
else: else:
shape_str = str(tensors.get_shape().as_list()) shape_str = str(tensors.get_shape().as_list())
return shape_str return shape_str
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment