Commit a50bb749 authored by Yuxin Wu

update readme

parent 05b18a47
@@ -4,8 +4,8 @@ Neural Network Toolbox on TensorFlow
 See some [examples](examples) to learn about the framework:
 ### Vision:
-+ [DoReFa-Net: training binary / low-bitwidth CNN on ImageNet](examples/DoReFa-Net)
-+ [ResNet for ImageNet/Cifar10/SVHN](examples/ResNet)
++ [DoReFa-Net: train binary / low-bitwidth CNN on ImageNet](examples/DoReFa-Net)
++ [Train ResNet on ImageNet/Cifar10/SVHN](examples/ResNet)
 + [InceptionV3 on ImageNet](examples/Inception/inceptionv3.py)
 + [Fully-convolutional Network for Holistically-Nested Edge Detection(HED)](examples/HED)
 + [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer)
......
@@ -29,10 +29,12 @@ Play with the [pretrained model](https://drive.google.com/drive/folders/0B9IPQTv
 Image-to-Image following the setup in [pix2pix](https://github.com/phillipi/pix2pix).
 It requires the datasets released by the original authors.
-With the cityscapes dataset, it learns to generate semantic segmentation map of urban scene:
+For example, with the cityscapes dataset, it learns to generate semantic segmentation map of urban scene:
 ![im2im](demo/im2im-cityscapes.jpg)
+This is a visualization from tensorboard. Left to right: original, ground truth, model output.
 ## InfoGAN-mnist.py
 Reproduce one mnist experiement in InfoGAN.
......
 ## imagenet-resnet.py
-Training code of pre-activation ResNet on ImageNet. It follows the setup in
+__Training__ code of pre-activation ResNet on ImageNet. It follows the setup in
 [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) and gets similar performance (with much fewer lines of code).
 Models can be [downloaded here](https://goo.gl/6XjK9V).
......
@@ -124,6 +124,19 @@ def BilinearUpSample(x, shape):
     :param x: input NHWC tensor
     :param shape: an integer, the upsample factor
     """
+    #inp_shape = tf.shape(x)
+    #return tf.image.resize_bilinear(x,
+    #    tf.pack([inp_shape[1]*shape,inp_shape[2]*shape]),
+    #    align_corners=True)
+    inp_shape = x.get_shape().as_list()
+    ch = inp_shape[3]
+    assert ch is not None
+    shape = int(shape)
+    filter_shape = 2 * shape
     def bilinear_conv_filler(s):
         """
         s: width, height of the conv filter
@@ -136,13 +149,6 @@ def BilinearUpSample(x, shape):
             for y in range(s):
                 ret[x,y] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
         return ret
-    inp_shape = x.get_shape().as_list()
-    ch = inp_shape[3]
-    assert ch is not None
-    shape = int(shape)
-    filter_shape = 2 * shape
     w = bilinear_conv_filler(filter_shape)
     w = np.repeat(w, ch * ch).reshape((filter_shape, filter_shape, ch, ch))
     weight_var = tf.constant(w, tf.float32,
......
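A note on what the hunks above move around: bilinear_conv_filler fills a filter_shape x filter_shape kernel with bilinear-interpolation weights, and BilinearUpSample uses it as a fixed (tf.constant, hence non-learnable) convolution filter. The definitions of f and c are not visible in this diff, so the minimal sketch below assumes the usual Caffe-style bilinear filler (f = ceil(s/2), c = (2f - 1 - f%2) / (2f)) and only reproduces the kernel for an upsample factor of 2.

import numpy as np

def bilinear_conv_filler(s):
    # Assumed Caffe-style bilinear filler; f and c are not shown in the hunk above.
    f = np.ceil(float(s) / 2)
    c = float(2 * f - 1 - f % 2) / (2. * f)
    ret = np.zeros((s, s), dtype='float32')
    for x in range(s):
        for y in range(s):
            ret[x, y] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
    return ret

# Upsample factor 2 -> 4x4 kernel filter_shape = 4, which comes out as the outer
# product of [0.25, 0.75, 0.75, 0.25] with itself, i.e. first row
# [0.0625, 0.1875, 0.1875, 0.0625].
print(bilinear_conv_filler(4))

In the layer itself this kernel is then replicated across the channel dimensions (the np.repeat / reshape line) and passed to tf.constant, so the upsampling weights are fixed rather than trained.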
@@ -168,7 +168,6 @@ class ParamRestore(SessionInit):
         for k in param_names - variable_names:
             logger.warn("Variable {} in the dict not found in the graph!".format(k))
         upd = SessionUpdate(sess,
                 [v for v in variables if \
                     get_savename_from_varname(v.name) in intersect])
......
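For context on the ParamRestore hunk: param_names, variable_names and intersect are not defined in the visible lines; presumably they are the names in the saved dict, the (normalized) names of the graph variables, and their set intersection. The following is a rough, self-contained sketch of that matching logic under those assumptions, with hypothetical example names; it is not tensorpack's actual implementation.

# Rough sketch of the name matching, with hypothetical data:
# `params` stands in for the saved name -> value dict, `var_names` for the graph
# variable names after get_savename_from_varname().
params = {'conv0/W': 0, 'conv0/b': 1, 'fc/W': 2}     # hypothetical saved dict
var_names = ['conv0/W', 'conv0/b', 'bn/gamma']       # hypothetical graph names

param_names = set(params.keys())
variable_names = set(var_names)
intersect = param_names & variable_names             # restorable: present on both sides

for k in param_names - variable_names:
    print("Variable {} in the dict not found in the graph!".format(k))   # warns about 'fc/W'

restorable = [n for n in var_names if n in intersect]   # -> ['conv0/W', 'conv0/b']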