Commit 412acd12 authored by Yuxin Wu

docs, fix resize

parent f16aef9d
...@@ -14,17 +14,21 @@ from tensorpack import * ...@@ -14,17 +14,21 @@ from tensorpack import *
from tensorpack.utils.viz import * from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average
import tensorpack.tfutils.symbolic_functions as symbf import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, RandomZData, build_GAN_losses from GAN import GANTrainer, build_GAN_losses
""" """
To train: To train:
./Image2Image.py --data /path/to/datadir --mode {AtoB,BtoA} ./Image2Image.py --data /path/to/datadir --mode {AtoB,BtoA}
# datadir should contain images of shape 2s x s, formed by A and B # datadir should contain jpg images of shape 2s x s, formed by A and B
# you can download some data from the original pix2pix repo: https://github.com/phillipi/pix2pix#datasets # you can download some data from the original authors: https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/
# training visualization will appear in tensorboard # training visualization will appear in tensorboard
Speed:
On GTX1080 with BATCH=1, the speed is about 9.3it/s (the original torch version is 9.5it/s)
To visualize on test set: To visualize on test set:
./Image2Image.py --sample --data /path/to/test/datadir --mode {AtoB,BtoA} --load pretrained.model ./Image2Image.py --sample --data /path/to/test/datadir --mode {AtoB,BtoA} --load model
""" """
SHAPE = 256 SHAPE = 256
...@@ -41,7 +45,7 @@ class Model(ModelDesc): ...@@ -41,7 +45,7 @@ class Model(ModelDesc):
def generator(self, imgs): def generator(self, imgs):
# imgs: input: 256x256xch # imgs: input: 256x256xch
# U-Net structure, slightly different from the original on the location of relu/lrelu # U-Net structure, it's slightly different from the original on the location of relu/lrelu
with argscope(BatchNorm, use_local_stat=True), \ with argscope(BatchNorm, use_local_stat=True), \
argscope(Dropout, is_training=True): argscope(Dropout, is_training=True):
# always use local stat for BN, and apply dropout even in testing # always use local stat for BN, and apply dropout even in testing
...@@ -118,7 +122,7 @@ class Model(ModelDesc): ...@@ -118,7 +122,7 @@ class Model(ModelDesc):
fake_output = tf.image.grayscale_to_rgb(fake_output) fake_output = tf.image.grayscale_to_rgb(fake_output)
viz = (tf.concat(2, [input, output, fake_output]) + 1.0) * 128.0 viz = (tf.concat(2, [input, output, fake_output]) + 1.0) * 128.0
viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz') viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
tf.summary.image('gen', viz, max_outputs=max(30, BATCH)) tf.summary.image('input,output,fake', viz, max_outputs=max(30, BATCH))
all_vars = tf.trainable_variables() all_vars = tf.trainable_variables()
self.g_vars = [v for v in all_vars if v.name.startswith('gen/')] self.g_vars = [v for v in all_vars if v.name.startswith('gen/')]
...@@ -134,9 +138,9 @@ def split_input(img): ...@@ -134,9 +138,9 @@ def split_input(img):
if args.mode == 'BtoA': if args.mode == 'BtoA':
input, output = output, input input, output = output, input
if IN_CH == 1: if IN_CH == 1:
input = cv2.cvtColor(input, cv2.COLOR_RGB2GRAY) input = cv2.cvtColor(input, cv2.COLOR_RGB2GRAY)[:,:,np.newaxis]
if OUT_CH == 1: if OUT_CH == 1:
output = cv2.cvtColor(output, cv2.COLOR_RGB2GRAY) output = cv2.cvtColor(output, cv2.COLOR_RGB2GRAY)[:,:,np.newaxis]
return [input, output] return [input, output]
def get_data(): def get_data():
......
...@@ -8,7 +8,7 @@ Reproduce the following GAN-related papers: ...@@ -8,7 +8,7 @@ Reproduce the following GAN-related papers:
+ InfoGAN: Interpretable Representation Learning by Information Maximizing GAN. [paper](https://arxiv.org/abs/1606.03657) + InfoGAN: Interpretable Representation Learning by Information Maximizing GAN. [paper](https://arxiv.org/abs/1606.03657)
Detailed usage is in the docstring of each script. Please see the __docstring__ in each script for detailed usage.
## DCGAN-CelebA.py ## DCGAN-CelebA.py
...@@ -27,7 +27,6 @@ Play with the [pretrained model](https://drive.google.com/drive/folders/0B9IPQTv ...@@ -27,7 +27,6 @@ Play with the [pretrained model](https://drive.google.com/drive/folders/0B9IPQTv
## Image2Image.py ## Image2Image.py
Image-to-Image following the setup in [pix2pix](https://github.com/phillipi/pix2pix). Image-to-Image following the setup in [pix2pix](https://github.com/phillipi/pix2pix).
It requires the datasets released by the original authors.
For example, with the cityscapes dataset, it learns to generate semantic segmentation map of urban scene: For example, with the cityscapes dataset, it learns to generate semantic segmentation map of urban scene:
......
...@@ -19,7 +19,7 @@ class ImageFromFile(RNGDataFlow): ...@@ -19,7 +19,7 @@ class ImageFromFile(RNGDataFlow):
:param channel: 1 or 3 channel :param channel: 1 or 3 channel
:param resize: a (h, w) tuple. If given, will force a resize :param resize: a (h, w) tuple. If given, will force a resize
""" """
assert len(files), "No Image Files!" assert len(files), "No image files given to ImageFromFile!"
self.files = files self.files = files
self.channel = int(channel) self.channel = int(channel)
self.imread_mode = cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR self.imread_mode = cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR
......
...@@ -55,9 +55,12 @@ class Resize(ImageAugmentor): ...@@ -55,9 +55,12 @@ class Resize(ImageAugmentor):
self._init(locals()) self._init(locals())
def _augment(self, img, _): def _augment(self, img, _):
return cv2.resize( ret = cv2.resize(
img, self.shape[::-1], img, self.shape[::-1],
interpolation=self.interp) interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:,:,np.newaxis]
return ret
class ResizeShortestEdge(ImageAugmentor): class ResizeShortestEdge(ImageAugmentor):
""" Resize the shortest edge to a certain number while """ Resize the shortest edge to a certain number while
...@@ -71,8 +74,10 @@ class ResizeShortestEdge(ImageAugmentor): ...@@ -71,8 +74,10 @@ class ResizeShortestEdge(ImageAugmentor):
h, w = img.shape[:2] h, w = img.shape[:2]
scale = self.size / min(h, w) scale = self.size / min(h, w)
desSize = map(int, [scale * w, scale * h]) desSize = map(int, [scale * w, scale * h])
img = cv2.resize(img, tuple(desSize), interpolation=cv2.INTER_CUBIC) ret = cv2.resize(img, tuple(desSize), interpolation=cv2.INTER_CUBIC)
return img if img.ndim == 3 and ret.ndim == 2:
ret = ret[:,:,np.newaxis]
return ret
class RandomResize(ImageAugmentor): class RandomResize(ImageAugmentor):
""" randomly rescale w and h of the image""" """ randomly rescale w and h of the image"""
...@@ -105,5 +110,8 @@ class RandomResize(ImageAugmentor): ...@@ -105,5 +110,8 @@ class RandomResize(ImageAugmentor):
return img.shape[1], img.shape[0] return img.shape[1], img.shape[0]
def _augment(self, img, dsize): def _augment(self, img, dsize):
return cv2.resize(img, dsize, interpolation=self.interp) ret = cv2.resize(img, dsize, interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:,:,np.newaxis]
return ret
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment