Commit 412acd12 authored by Yuxin Wu's avatar Yuxin Wu

docs, fix resize

parent f16aef9d
......@@ -14,17 +14,21 @@ from tensorpack import *
from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average
import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, RandomZData, build_GAN_losses
from GAN import GANTrainer, build_GAN_losses
"""
To train:
./Image2Image.py --data /path/to/datadir --mode {AtoB,BtoA}
# datadir should contain images of shpae 2s x s, formed by A and B
# you can download some data from the original pix2pix repo: https://github.com/phillipi/pix2pix#datasets
# datadir should contain jpg images of shpae 2s x s, formed by A and B
# you can download some data from the original authors: https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/
# training visualization will appear be in tensorboard
Speed:
On GTX1080 with BATCH=1, the speed is about 9.3it/s (the original torch version is 9.5it/s)
To visualize on test set:
./Image2Image.py --sample --data /path/to/test/datadir --mode {AtoB,BtoA} --load pretrained.model
./Image2Image.py --sample --data /path/to/test/datadir --mode {AtoB,BtoA} --load model
"""
SHAPE = 256
......@@ -41,7 +45,7 @@ class Model(ModelDesc):
def generator(self, imgs):
# imgs: input: 256x256xch
# U-Net structure, slightly different from the original on the location of relu/lrelu
# U-Net structure, it's slightly different from the original on the location of relu/lrelu
with argscope(BatchNorm, use_local_stat=True), \
argscope(Dropout, is_training=True):
# always use local stat for BN, and apply dropout even in testing
......@@ -118,7 +122,7 @@ class Model(ModelDesc):
fake_output = tf.image.grayscale_to_rgb(fake_output)
viz = (tf.concat(2, [input, output, fake_output]) + 1.0) * 128.0
viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
tf.summary.image('gen', viz, max_outputs=max(30, BATCH))
tf.summary.image('input,output,fake', viz, max_outputs=max(30, BATCH))
all_vars = tf.trainable_variables()
self.g_vars = [v for v in all_vars if v.name.startswith('gen/')]
......@@ -134,9 +138,9 @@ def split_input(img):
if args.mode == 'BtoA':
input, output = output, input
if IN_CH == 1:
input = cv2.cvtColor(input, cv2.COLOR_RGB2GRAY)
input = cv2.cvtColor(input, cv2.COLOR_RGB2GRAY)[:,:,np.newaxis]
if OUT_CH == 1:
output = cv2.cvtColor(output, cv2.COLOR_RGB2GRAY)
output = cv2.cvtColor(output, cv2.COLOR_RGB2GRAY)[:,:,np.newaxis]
return [input, output]
def get_data():
......
......@@ -8,7 +8,7 @@ Reproduce the following GAN-related papers:
+ InfoGAN: Interpretable Representation Learning by Information Maximizing GAN. [paper](https://arxiv.org/abs/1606.03657)
Detailed usage is in the docstring of each script.
Please see the __docstring__ in each script for detailed usage.
## DCGAN-CelebA.py
......@@ -27,7 +27,6 @@ Play with the [pretrained model](https://drive.google.com/drive/folders/0B9IPQTv
## Image2Image.py
Image-to-Image following the setup in [pix2pix](https://github.com/phillipi/pix2pix).
It requires the datasets released by the original authors.
For example, with the cityscapes dataset, it learns to generate semantic segmentation map of urban scene:
......
......@@ -19,7 +19,7 @@ class ImageFromFile(RNGDataFlow):
:param channel: 1 or 3 channel
:param resize: a (h, w) tuple. If given, will force a resize
"""
assert len(files), "No Image Files!"
assert len(files), "No image files given to ImageFromFile!"
self.files = files
self.channel = int(channel)
self.imread_mode = cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR
......
......@@ -55,9 +55,12 @@ class Resize(ImageAugmentor):
self._init(locals())
def _augment(self, img, _):
return cv2.resize(
ret = cv2.resize(
img, self.shape[::-1],
interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:,:,np.newaxis]
return ret
class ResizeShortestEdge(ImageAugmentor):
""" Resize the shortest edge to a certain number while
......@@ -71,8 +74,10 @@ class ResizeShortestEdge(ImageAugmentor):
h, w = img.shape[:2]
scale = self.size / min(h, w)
desSize = map(int, [scale * w, scale * h])
img = cv2.resize(img, tuple(desSize), interpolation=cv2.INTER_CUBIC)
return img
ret = cv2.resize(img, tuple(desSize), interpolation=cv2.INTER_CUBIC)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:,:,np.newaxis]
return ret
class RandomResize(ImageAugmentor):
""" randomly rescale w and h of the image"""
......@@ -105,5 +110,8 @@ class RandomResize(ImageAugmentor):
return img.shape[1], img.shape[0]
def _augment(self, img, dsize):
return cv2.resize(img, dsize, interpolation=self.interp)
ret = cv2.resize(img, dsize, interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:,:,np.newaxis]
return ret
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment