Commit 6f28c94e authored by Yuxin Wu

bug fix in viz

parent 52a4a0a8
# tensorpack
Neural Network Toolbox on TensorFlow
-Still in development. Underlying design may change.
See some [examples](examples) to learn about the framework.
You can actually train them and reproduce the performance... not just to see how to write code.
-+ [DoReFa-Net: training binary / low bitwidth CNN](examples/DoReFa-Net)
++ [DoReFa-Net: training binary / low bitwidth CNN on ImageNet](examples/DoReFa-Net)
+ [ResNet for ImageNet/Cifar10/SVHN classification](examples/ResNet)
+ [InceptionV3 on ImageNet](examples/Inception/inceptionv3.py)
+ [Fully-convolutional Network for Holistically-Nested Edge Detection](examples/HED)
@@ -27,7 +25,7 @@ Describe your training task with three components:
2. Data. tensorpack allows and encourages complex data processing.
+ All data producers have a unified `generator` interface, allowing them to be composed to perform complex preprocessing.
-+ Use Python to easily handle any of your own data format, yet still keep a good training speed thanks to multiprocess prefetch & TF Queue prefetch.
++ Use Python to easily handle any data format, yet still keep a good training speed thanks to multiprocess prefetch & TF Queue prefetch.
For example, InceptionV3 can run at the same speed as the official code, which reads data using TF operators.
3. Callbacks, including everything you want to do apart from the training iterations, such as:
@@ -39,6 +37,7 @@ Describe your training task with three components:
With the above components defined, the tensorpack trainer will run the training iterations for you.
Multi-GPU training is off-the-shelf by simply switching the trainer.
+You can also define your own trainer for non-standard training (e.g. GAN).
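To make the data pipeline described in point 2 concrete, here is a rough sketch (not part of this commit) of composing DataFlows in that generator style. `ImageFromFile`, `MapData`, and `BatchData` appear in the diff further down this page; the import path, `PrefetchData`, `reset_state()`, and `get_data()` are assumptions about the library's API rather than anything confirmed here.

```python
# Hypothetical sketch: compose DataFlows into a preprocessing pipeline.
# PrefetchData / reset_state / get_data and the import path are assumed names.
import glob
import os

from tensorpack.dataflow import ImageFromFile, MapData, BatchData, PrefetchData


def build_pipeline(datadir):
    imgs = glob.glob(os.path.join(datadir, '*.jpg'))
    ds = ImageFromFile(imgs, channel=3, shuffle=True)   # yields [img] datapoints
    ds = MapData(ds, lambda dp: [dp[0] / 255.0])        # arbitrary Python preprocessing
    ds = BatchData(ds, 64)                              # stack into batches of 64
    ds = PrefetchData(ds, 256, 4)                       # assumed: prefetch with 4 worker processes
    return ds


if __name__ == '__main__':
    df = build_pipeline('/path/to/jpgs')
    df.reset_state()             # assumed: initialize the dataflow before iterating
    for dp in df.get_data():     # the unified generator interface
        print(dp[0].shape)       # e.g. (64, height, width, 3)
        break
```

Each wrapper consumes another DataFlow, which is what makes the composition possible; the prefetch layer is what keeps pure-Python preprocessing from bottlenecking training.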
## Dependencies:
......
@@ -174,12 +174,12 @@ def sample(datadir, model_path):
    imgs = glob.glob(os.path.join(datadir, '*.jpg'))
    ds = ImageFromFile(imgs, channel=3, shuffle=True)
-    ds = BatchData(MapData(ds, lambda dp: split_input(dp[0])), 16)
+    ds = BatchData(MapData(ds, lambda dp: split_input(dp[0])), 6)
    pred = SimpleDatasetPredictor(pred, ds)
    for o in pred.get_result():
-        o = o[:,:,:,::-1]
-        viz = next(build_patch_list(o, nr_row=4, nr_col=4, viz=True))
+        o = o[0][:,:,:,::-1]
+        viz = next(build_patch_list(o, nr_row=3, nr_col=2, viz=True))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
@@ -200,4 +200,3 @@ if __name__ == '__main__':
    if args.load:
        config.session_init = SaverRestore(args.load)
    GANTrainer(config, g_vs_d=1).train()
# Generative Adversarial Networks
-See the docstring in the script for detailed usage.
## DCGAN-CelebA.py
Reproduce DCGAN following the setup in [dcgan.torch](https://github.com/soumith/dcgan.torch).
Play with the [pretrained model](https://drive.google.com/drive/folders/0B9IPQTvr2BBkLUF2M0RXU1NYSkE?usp=sharing) on CelebA face dataset:
-1. Generated samples
++ Generated samples
![sample](demo/CelebA-samples.jpg)
-2. Vector arithmetic: smiling woman - neutral woman + neutral man = smiling man
++ Vector arithmetic: smiling woman - neutral woman + neutral man = smiling man
![vec](demo/CelebA-vec.jpg)
+See the docstring in the script for usage.
## Image2Image.py
Reproduce [Image-to-image Translation with Conditional Adversarial Networks](https://arxiv.org/pdf/1611.07004v1.pdf),
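The vector-arithmetic demo above comes down to one line of algebra in the latent space. Below is a minimal numpy sketch of the idea; `generate` stands for a hypothetical call into the trained generator and is not an API of DCGAN-CelebA.py, and the latent size is an assumption.

```python
import numpy as np


def combine(z_smiling_woman, z_neutral_woman, z_neutral_man):
    """smiling woman - neutral woman + neutral man = smiling man, in z-space."""
    # each argument: (N, z_dim) latent codes sampled for images with that attribute
    z = (z_smiling_woman.mean(axis=0)
         - z_neutral_woman.mean(axis=0)
         + z_neutral_man.mean(axis=0))
    return z[None, :]           # (1, z_dim), ready to feed to the generator


rng = np.random.RandomState(0)
z_dim = 100                     # the usual DCGAN latent size; an assumption here
groups = [rng.normal(size=(4, z_dim)) for _ in range(3)]
z_smiling_man = combine(*groups)
# img = generate(z_smiling_man)   # hypothetical: run the trained generator on it
```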
......
@@ -18,7 +18,7 @@ __all__ = ['OnlinePredictor', 'OfflinePredictor',
class PredictorBase(object):
    __metaclass__ = ABCMeta
    """
-    Property:
+    Available attributes:
        session
        return_input
    """
......
@@ -78,7 +78,7 @@ def build_patch_list(patch_list,
    """
    Generate patches.
    :param patch_list: bhw or bhwc images in [0,255]
-    :param border: defaults to 0.1 * max(image_width, image_height)
+    :param border: defaults to 0.1 * min(image_width, image_height)
    :param nr_row, nr_col: rows and cols of the grid
    :param max_width, max_height: if nr_row/col are not given, use this to infer the rows and cols
    :param shuffle: shuffle the images
@@ -97,7 +97,7 @@ def build_patch_list(patch_list,
        viz = True
    ph, pw = patch_list.shape[1:3]
    if border is None:
-        border = int(0.1 * max(ph, pw))
+        border = int(0.1 * min(ph, pw))
    mh, mw = max(max_height, ph + border), max(max_width, pw + border)
    if nr_row is None:
        nr_row = minnone(nr_row, max_height / (ph + border))
......
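For reference, a short sketch of calling the corrected `build_patch_list`, mirroring the `nr_row=3, nr_col=2` usage in the sample() hunk above. The dummy patches and the import path are assumptions; only the argument names come from this diff.

```python
import numpy as np

from tensorpack.utils.viz import build_patch_list   # assumed import path

# six fake bhwc patches in [0, 255], matching what the docstring expects
patches = (np.random.rand(6, 64, 64, 3) * 255).astype('uint8')

# Tile into a 3x2 grid. After this fix the default border is 0.1 * min(h, w).
# build_patch_list is a generator, hence the next(); the sample() hunk above
# additionally passes viz=True to display the grid, which is left off here.
grid = next(build_patch_list(patches, nr_row=3, nr_col=2, viz=False))
print(grid.shape)
```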