Commit 6f28c94e authored by Yuxin Wu's avatar Yuxin Wu

bug fix in viz

parent 52a4a0a8
# tensorpack
Neural Network Toolbox on TensorFlow
Still in development. Underlying design may change.
See some [examples](examples) to learn about the framework.
You can actually train them and reproduce the performance... not just to see how to write code.
+ [DoReFa-Net: training binary / low bitwidth CNN](examples/DoReFa-Net)
+ [DoReFa-Net: training binary / low bitwidth CNN on ImageNet](examples/DoReFa-Net)
+ [ResNet for ImageNet/Cifar10/SVHN classification](examples/ResNet)
+ [InceptionV3 on ImageNet](examples/Inception/inceptionv3.py)
+ [Fully-convolutional Network for Holistically-Nested Edge Detection](examples/HED)
......@@ -27,7 +25,7 @@ Describe your training task with three components:
2. Data. tensorpack allows and encourages complex data processing.
+ All data producer has an unified `generator` interface, allowing them to be composed to perform complex preprocessing.
+ Use Python to easily handle any of your own data format, yet still keep a good training speed thanks to multiprocess prefetch & TF Queue prefetch.
+ Use Python to easily handle any data format, yet still keep a good training speed thanks to multiprocess prefetch & TF Queue prefetch.
For example, InceptionV3 can run in the same speed as the official code which reads data using TF operators.
3. Callbacks, including everything you want to do apart from the training iterations, such as:
......@@ -39,6 +37,7 @@ Describe your training task with three components:
With the above components defined, tensorpack trainer will run the training iterations for you.
Multi-GPU training is off-the-shelf by simply switching the trainer.
You can also define your own trainer for non-standard training (e.g. GAN).
## Dependencies:
......
......@@ -174,12 +174,12 @@ def sample(datadir, model_path):
imgs = glob.glob(os.path.join(datadir, '*.jpg'))
ds = ImageFromFile(imgs, channel=3, shuffle=True)
ds = BatchData(MapData(ds, lambda dp: split_input(dp[0])), 16)
ds = BatchData(MapData(ds, lambda dp: split_input(dp[0])), 6)
pred = SimpleDatasetPredictor(pred, ds)
for o in pred.get_result():
o = o[:,:,:,::-1]
viz = next(build_patch_list(o, nr_row=4, nr_col=4, viz=True))
o = o[0][:,:,:,::-1]
viz = next(build_patch_list(o, nr_row=3, nr_col=2, viz=True))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
......@@ -200,4 +200,3 @@ if __name__ == '__main__':
if args.load:
config.session_init = SaverRestore(args.load)
GANTrainer(config, g_vs_d=1).train()
# Generative Adversarial Networks
See the docstring in the script for detailed usage.
## DCGAN-CelebA.py
Reproduce DCGAN following the setup in [dcgan.torch](https://github.com/soumith/dcgan.torch).
Play with the [pretrained model](https://drive.google.com/drive/folders/0B9IPQTvr2BBkLUF2M0RXU1NYSkE?usp=sharing) on CelebA face dataset:
1. Generated samples
+ Generated samples
![sample](demo/CelebA-samples.jpg)
2. Vector arithmetic: smiling woman - neutral woman + neutral man = smiling man
+ Vector arithmetic: smiling woman - neutral woman + neutral man = smiling man
![vec](demo/CelebA-vec.jpg)
See the docstring in the script for usage.
## Image2Image.py
Reproduce [Image-to-image Translation with Conditional Adversarial Networks](https://arxiv.org/pdf/1611.07004v1.pdf),
......
......@@ -18,7 +18,7 @@ __all__ = ['OnlinePredictor', 'OfflinePredictor',
class PredictorBase(object):
__metaclass__ = ABCMeta
"""
Property:
Available attributes:
session
return_input
"""
......
......@@ -78,7 +78,7 @@ def build_patch_list(patch_list,
"""
Generate patches.
:param patch_list: bhw or bhwc images in [0,255]
:param border: defaults to 0.1 * max(image_width, image_height)
:param border: defaults to 0.1 * min(image_width, image_height)
:param nr_row, nr_col: rows and cols of the grid
:parma max_width, max_height: if nr_row/col are not given, use this to infer the rows and cols
:param shuffle: shuffle the images
......@@ -97,7 +97,7 @@ def build_patch_list(patch_list,
viz = True
ph, pw = patch_list.shape[1:3]
if border is None:
border = int(0.1 * max(ph, pw))
border = int(0.1 * min(ph, pw))
mh, mw = max(max_height, ph + border), max(max_width, pw + border)
if nr_row is None:
nr_row = minnone(nr_row, max_height / (ph + border))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment