Commit 9870d216 authored by Yuxin Wu's avatar Yuxin Wu

dynamic shape in deconv2d

parent 5aaf6410
......@@ -2,15 +2,15 @@
Neural Network Toolbox on TensorFlow
See some [examples](examples) to learn about the framework.
You can actually train them and reproduce the performance... not just to see how to write code.
You can train them and reproduce the performance... not just to see how to write code.
+ [DoReFa-Net: training binary / low bitwidth CNN on ImageNet](examples/DoReFa-Net)
+ [ResNet for ImageNet/Cifar10/SVHN classification](examples/ResNet)
+ [InceptionV3 on ImageNet](examples/Inception/inceptionv3.py)
+ [Fully-convolutional Network for Holistically-Nested Edge Detection](examples/HED)
+ [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer)
+ [Generative Adversarial Networks(GAN) variants](examples/GAN)
+ [DQN variants on Atari games](examples/Atari2600)
+ [Fully-convolutional Network for Holistically-Nested Edge Detection(HED)](examples/HED)
+ [Spatial Transformer Network on MNIST addition](examples/SpatialTransformer)
+ [Generative Adversarial Network(GAN) variants](examples/GAN)
+ [Deep Q-Network(DQN) variants on Atari games](examples/Atari2600)
+ [Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym](examples/OpenAIGym)
+ [char-rnn language model](examples/char-rnn)
......@@ -20,13 +20,13 @@ Describe your training task with three components:
1. __Model__, or graph. `models/` has some scoped abstraction of common models, but you can simply use
any symbolic functions available in tensorflow, or most functions in slim/tflearn/tensorlayer.
`LinearWrap` and `argscope` makes large models look simpler ([vgg example](https://github.com/ppwwyyxx/tensorpack/blob/master/examples/load-vgg16.py)).
`LinearWrap` and `argscope` simplify large models ([vgg example](https://github.com/ppwwyyxx/tensorpack/blob/master/examples/load-vgg16.py)).
2. __DataFlow__. tensorpack allows and encourages complex data processing.
+ All data producers have a unified `generator` interface, allowing them to be composed to perform complex preprocessing.
+ Use Python to easily handle any data format, yet still keep a good training speed thanks to multiprocess prefetch & TF Queue prefetch.
For example, InceptionV3 can run in the same speed as the official code which reads data using TF operators.
+ Use Python to easily handle any data format, yet still keep good performance thanks to multiprocess prefetch & TF Queue prefetch.
For example, InceptionV3 can run at the same speed as the official code, which reads data with TF operators.
3. __Callbacks__, including everything you want to do apart from the training iterations, such as:
+ Change hyperparameters during training
......
......@@ -126,30 +126,29 @@ def sample(model_path):
o = o[:,:,:,::-1]
viz = next(build_patch_list(o, nr_row=10, nr_col=10, viz=True))
def vec(model_path):
    """
    Run the vector-arithmetic demo with a trained generator.

    Loads the latent vectors stored in ``demo/CelebA-vec.npy``, checks that
    ``w_smile - w_neutral + m_neutral`` equals ``m_smile`` element-wise, then
    feeds each of the four vectors to the ``gen/gen`` output of the model and
    shows the four generated images in a single row.

    :param model_path: path of the model file, passed to ``get_model_loader``
    """
    predictor = OfflinePredictor(PredictConfig(
        session_init=get_model_loader(model_path),
        model=Model(),
        input_names=['z'],
        output_names=['gen/gen']))

    vectors = np.load('demo/CelebA-vec.npy').item()
    # sanity check on the stored vectors before generating anything
    arithmetic = vectors['w_smile'] - vectors['w_neutral'] + vectors['m_neutral']
    assert np.all(arithmetic == vectors['m_smile'])

    def render(key):
        # first sample of the first output, with the last axis reversed
        # (presumably a channel-order flip -- confirm against the viewer)
        picture = predictor([[vectors[key]]])[0][0][:, :, ::-1]
        # NOTE(review): looks like a rescale from [-1, 1] to pixel range -- confirm
        return (picture + 1) * 128

    imgs = [render(key)
            for key in ['w_neutral', 'w_smile', 'm_neutral', 'm_smile']]
    # pulling the first item from the generator is what triggers the display
    viz = next(build_patch_list(imgs, nr_row=1, nr_col=4, viz=True))
#def vec(model_path):
#func = OfflinePredictor(PredictConfig(
#session_init=get_model_loader(model_path),
#model=Model(),
#input_names=['z'],
#output_names=['gen/gen']))
#dic = np.load('demo/CelebA-vec.npy').item()
#assert np.all(
#dic['w_smile'] - dic['w_neutral'] \
#+ dic['m_neutral'] == dic['m_smile'])
#imgs = []
#for z in ['w_neutral', 'w_smile', 'm_neutral', 'm_smile']:
#z = dic[z]
#img = func([[z]])[0][0][:,:,::-1]
#img = (img + 1) * 128
#imgs.append(img)
#viz = next(build_patch_list(imgs, nr_row=1, nr_col=4, viz=True))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
parser.add_argument('--sample', action='store_true', help='run sampling')
parser.add_argument('--vec', action='store_true', help='run vec arithmetic demo')
parser.add_argument('--data', help='`image_align_celeba` directory of the celebA dataset')
args = parser.parse_args()
use_global_argument(args)
......@@ -157,8 +156,6 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.sample:
sample(args.load)
elif args.vec:
vec(args.load)
else:
assert args.data
config = get_config()
......
......@@ -3,6 +3,7 @@
# File: load-alexnet.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from __future__ import print_function
import tensorflow as tf
import numpy as np
import os, cv2, argparse
......
......@@ -21,7 +21,7 @@ class TestDataSpeed(ProxyDataFlow):
self.test_size = size
def get_data(self):
with get_tqdm(total=range(self.test_size)) as pbar:
with get_tqdm(total=self.test_size) as pbar:
for dp in self.ds.get_data():
pbar.update()
for dp in self.ds.get_data():
......
......@@ -34,7 +34,7 @@ def Conv2D(x, out_channel, kernel_shape,
"""
in_shape = x.get_shape().as_list()
in_channel = in_shape[-1]
assert in_channel is not None, "Input to Conv2D cannot have unknown channel!"
assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!"
assert in_channel % split == 0
assert out_channel % split == 0
......@@ -65,6 +65,26 @@ def Conv2D(x, out_channel, kernel_shape,
nl = tf.nn.relu
return nl(tf.nn.bias_add(conv, b) if use_bias else conv, name='output')
class StaticDynamicShape(object):
    """
    Hold two views of the same shape value: one known statically (may be
    ``None`` when unknown) and one available dynamically.

    The helpers apply a function to the static value when that works and
    fall back to the dynamic value (or to ``None``) when it does not.
    """

    def __init__(self, static, dynamic):
        """
        :param static: the statically-known value, or ``None`` if unknown
        :param dynamic: the dynamic counterpart of the same value
        """
        self.static = static
        self.dynamic = dynamic

    def apply_dynamic(self, f):
        """
        Return ``f(static)`` if it can be computed, otherwise ``f(dynamic)``.

        :param f: a one-argument callable
        """
        try:
            return f(self.static)
        except Exception:
            # Only ordinary errors (e.g. arithmetic on None) trigger the
            # fallback; KeyboardInterrupt / SystemExit must propagate.
            return f(self.dynamic)

    def apply_static(self, f):
        """
        Return ``f(static)`` if it can be computed, otherwise ``None``.

        :param f: a one-argument callable
        """
        try:
            return f(self.static)
        except Exception:
            # The static value is unusable with ``f``: report it as unknown.
            return None

    def apply(self, f):
        """
        Apply ``f`` to both views.

        :param f: a one-argument callable
        :returns: a new :class:`StaticDynamicShape` built from
            ``apply_static(f)`` and ``apply_dynamic(f)``
        """
        return StaticDynamicShape(self.apply_static(f), self.apply_dynamic(f))
@layer_register()
def Deconv2D(x, out_shape, kernel_shape,
stride, padding='SAME',
......@@ -86,8 +106,8 @@ def Deconv2D(x, out_shape, kernel_shape,
:returns: a NHWC tensor
"""
in_shape = x.get_shape().as_list()[1:]
assert None not in in_shape, "Input to Deconv2D cannot have unknown shape!"
in_channel = in_shape[-1]
assert in_channel is not None, "[Deconv2D] Input cannot have unknown channel!"
kernel_shape = shape2d(kernel_shape)
stride2d = shape2d(stride)
stride4d = shape4d(stride)
......@@ -95,10 +115,16 @@ def Deconv2D(x, out_shape, kernel_shape,
if isinstance(out_shape, int):
out_channel = out_shape
shape3 = [stride2d[0] * in_shape[0], stride2d[1] * in_shape[1], out_shape]
shp3_0 = StaticDynamicShape(in_shape[0], tf.shape(x)[1]).apply(lambda x: stride2d[0] * x)
shp3_1 = StaticDynamicShape(in_shape[1], tf.shape(x)[2]).apply(lambda x: stride2d[1] * x)
shp3_dyn = [shp3_0.dynamic, shp3_1.dynamic, out_channel]
shp3_static = [shp3_0.static, shp3_1.static, out_channel]
else:
for k in out_shape:
if not isinstance(k, int):
raise ValueError("[Deconv2D] out_shape is invalid!")
out_channel = out_shape[-1]
shape3 = out_shape
shp3_static = shp3_dyn = out_shape
filter_shape = kernel_shape + [out_channel, in_channel]
if W_init is None:
......@@ -109,7 +135,7 @@ def Deconv2D(x, out_shape, kernel_shape,
if use_bias:
b = tf.get_variable('b', [out_channel], initializer=b_init)
out_shape = tf.pack([tf.shape(x)[0]] + shape3)
conv = tf.nn.conv2d_transpose(x, W, out_shape, stride4d, padding=padding)
conv.set_shape(tf.TensorShape([None] + shape3))
out_shape_dyn = tf.pack([tf.shape(x)[0]] + shp3_dyn)
conv = tf.nn.conv2d_transpose(x, W, out_shape_dyn, stride4d, padding=padding)
conv.set_shape(tf.TensorShape([None] + shp3_static))
return nl(tf.nn.bias_add(conv, b) if use_bias else conv, name='output')
......@@ -194,6 +194,7 @@ def get_model_loader(filename):
Get a corresponding model loader by looking at the file name
:return: either a ParamRestore or SaverRestore
"""
assert os.path.isfile(filename), filename
if filename.endswith('.npy'):
return ParamRestore(np.load(filename, encoding='latin1').item())
else:
......
......@@ -144,7 +144,7 @@ def build_patch_list(patch_list,
start = end
def dump_dataflow_images(df, index=0, batched=True,
number=300, output_dir=None,
number=1000, output_dir=None,
scale=1, resize=None, viz=None,
flipRGB=False, exit_after=True):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment