Commit 8843f7ff authored by Yuxin Wu's avatar Yuxin Wu

viz and GAN demo

parent 79dbd183
...@@ -5,14 +5,13 @@ ...@@ -5,14 +5,13 @@
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import glob import glob, pickle
import os, sys import os, sys
import argparse import argparse
import cv2 import cv2
from tensorpack import * from tensorpack import *
from tensorpack.utils.viz import build_patch_list from tensorpack.utils.viz import *
from tensorpack.utils.viz import dump_dataflow_images
from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average
import tensorpack.tfutils.symbolic_functions as symbf import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, RandomZData, build_GAN_losses from GAN import GANTrainer, RandomZData, build_GAN_losses
...@@ -119,21 +118,38 @@ def sample(model_path): ...@@ -119,21 +118,38 @@ def sample(model_path):
session_init=get_model_loader(model_path), session_init=get_model_loader(model_path),
model=Model(), model=Model(),
input_names=['z'], input_names=['z'],
output_names=['gen/gen']) output_names=['gen/gen', 'z'])
pred = SimpleDatasetPredictor(pred, RandomZData((100, 100))) pred = SimpleDatasetPredictor(pred, RandomZData((100, 100)))
for o in pred.get_result(): for o in pred.get_result():
o = o[0] + 1 o, zs = o[0] + 1, o[1]
o = o * 128.0 o = o * 128.0
o = o[:,:,:,::-1] o = o[:,:,:,::-1]
viz = next(build_patch_list(o, nr_row=10, nr_col=10)) viz = next(build_patch_list(o, nr_row=10, nr_col=10, viz=True))
cv2.imshow("", viz)
cv2.waitKey() def interp(model_path):
func = OfflinePredictor(PredictConfig(
session_init=get_model_loader(model_path),
model=Model(),
input_names=['z'],
output_names=['gen/gen']))
dic = np.load('demo/CelebA-vec.npy').item()
assert np.all(
dic['w_smile'] - dic['w_neutral'] \
+ dic['m_neutral'] == dic['m_smile'])
imgs = []
for z in ['w_neutral', 'w_smile', 'm_neutral', 'm_smile']:
z = dic[z]
img = func([[z]])[0][0][:,:,::-1]
img = (img + 1) * 128
imgs.append(img)
viz = next(build_patch_list(imgs, nr_row=1, nr_col=4, viz=True))
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
parser.add_argument('--sample', action='store_true', help='run sampling') parser.add_argument('--sample', action='store_true', help='run sampling')
parser.add_argument('--interp', action='store_true', help='run interpolation')
parser.add_argument('--data', help='`image_align_celeba` directory of the celebA dataset') parser.add_argument('--data', help='`image_align_celeba` directory of the celebA dataset')
global args global args
args = parser.parse_args() args = parser.parse_args()
...@@ -141,6 +157,8 @@ if __name__ == '__main__': ...@@ -141,6 +157,8 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.sample: if args.sample:
sample(args.load) sample(args.load)
elif args.interp:
interp(args.load)
else: else:
assert args.data assert args.data
config = get_config() config = get_config()
......
...@@ -2,6 +2,12 @@ ...@@ -2,6 +2,12 @@
Reproduce DCGAN following the setup in [dcgan.torch](https://github.com/soumith/dcgan.torch). Reproduce DCGAN following the setup in [dcgan.torch](https://github.com/soumith/dcgan.torch).
More results to come. Samples from CelebA dataset:
See the docstring in each executable script for usage. ![sample](demo/CelebA-samples.jpg)
Vector arithmetic: smiling woman - neutral woman + neutral man = smiling man
![vec](demo/CelebA-vec.jpg)
See the docstring in the script for usage.
...@@ -16,7 +16,7 @@ except ImportError: ...@@ -16,7 +16,7 @@ except ImportError:
pass pass
__all__ = ['pyplot2img', 'build_patch_list', 'pyplot_viz', __all__ = ['pyplot2img', 'build_patch_list', 'pyplot_viz',
'dump_dataflow_images'] 'dump_dataflow_images', 'interactive_imshow']
def pyplot2img(plt): def pyplot2img(plt):
buf = io.BytesIO() buf = io.BytesIO()
...@@ -45,21 +45,55 @@ def minnone(x, y): ...@@ -45,21 +45,55 @@ def minnone(x, y):
elif y is None: y = x elif y is None: y = x
return min(x, y) return min(x, y)
def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
"""
:param lclick_cb: a callback(img, x, y) for left click
:param kwargs: can be {key_cb_a ... key_cb_z: callback(img)}
"""
name = 'random_window_name'
cv2.imshow(name, img)
def mouse_cb(event, x, y, *args):
if event == cv2.EVENT_LBUTTONUP and lclick_cb is not None:
lclick_cb(img, x, y)
elif event == cv2.EVENT_RBUTTONUP and rclick_cb is not None:
rclick_cb(img, x, y)
cv2.setMouseCallback(name, mouse_cb)
key = chr(cv2.waitKey(-1) & 0xff)
cb_name = 'key_cb_' + key
if cb_name in kwargs:
kwargs[cb_name](img)
elif key == 'q':
cv2.destroyWindow(name)
elif key == 'x':
sys.exit()
elif key == 's':
cv2.imwrite('out.png', img)
def build_patch_list(patch_list, def build_patch_list(patch_list,
nr_row=None, nr_col=None, border=None, nr_row=None, nr_col=None, border=None,
max_width=1000, max_height=1000, max_width=1000, max_height=1000,
shuffle=False, bgcolor=255): shuffle=False, bgcolor=255, viz=False, lclick_cb=None):
""" """
This is a generator. Generate patches.
patch_list: bhw or bhwc :param patch_list: bhw or bhwc
:param border: defaults to 0.1 * max(image_width, image_height) :param border: defaults to 0.1 * max(image_width, image_height)
:param nr_row, nr_col: rows and cols of the grid
:parma max_width, max_height: if nr_row/col are not given, use this to infer the rows and cols
:param shuffle: shuffle the images
:param bgcolor: background color
:param viz: use interactive imshow to visualize the results
:param lclick_cb: only useful when viz=True. a callback(patch, idx)
""" """
# setup parameters
patch_list = np.asarray(patch_list) patch_list = np.asarray(patch_list)
if patch_list.ndim == 3: if patch_list.ndim == 3:
patch_list = patch_list[:,:,:,np.newaxis] patch_list = patch_list[:,:,:,np.newaxis]
assert patch_list.ndim == 4 and patch_list.shape[3] in [1, 3], patch_list.shape assert patch_list.ndim == 4 and patch_list.shape[3] in [1, 3], patch_list.shape
if shuffle: if shuffle:
np.random.shuffle(patch_list) np.random.shuffle(patch_list)
if lclick_cb is not None:
viz = True
ph, pw = patch_list.shape[1:3] ph, pw = patch_list.shape[1:3]
if border is None: if border is None:
border = int(0.1 * max(ph, pw)) border = int(0.1 * max(ph, pw))
...@@ -87,12 +121,24 @@ def build_patch_list(patch_list, ...@@ -87,12 +121,24 @@ def build_patch_list(patch_list,
nr_patch = nr_row * nr_col nr_patch = nr_row * nr_col
start = 0 start = 0
def lclick_callback(img, x, y):
if lclick_cb is None:
return
x = x // (pw + border)
y = y // (pw + border)
idx = start + y * nr_col + x
if idx < end:
lclick_cb(patch_list[idx], idx)
while True: while True:
end = start + nr_patch end = start + nr_patch
cur_list = patch_list[start:end] cur_list = patch_list[start:end]
if not len(cur_list): if not len(cur_list):
return return
draw_patch(cur_list) draw_patch(cur_list)
if viz:
interactive_imshow(canvas, lclick_cb=lclick_callback)
yield canvas yield canvas
start = end start = end
...@@ -150,9 +196,7 @@ def dump_dataflow_images(df, index=0, batched=True, ...@@ -150,9 +196,7 @@ def dump_dataflow_images(df, index=0, batched=True,
if viz is not None and len(vizlist) >= vizsize: if viz is not None and len(vizlist) >= vizsize:
patch = next(build_patch_list( patch = next(build_patch_list(
vizlist[:vizsize], vizlist[:vizsize],
nr_row=viz[0], nr_col=viz[1])) nr_row=viz[0], nr_col=viz[1], viz=True))
cv2.imshow("df-viz", patch)
cv2.waitKey()
vizlist = vizlist[vizsize:] vizlist = vizlist[vizsize:]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment