Commit 8843f7ff authored by Yuxin Wu's avatar Yuxin Wu

viz and GAN demo

parent 79dbd183
......@@ -5,14 +5,13 @@
import numpy as np
import tensorflow as tf
import glob
import glob, pickle
import os, sys
import argparse
import cv2
from tensorpack import *
from tensorpack.utils.viz import build_patch_list
from tensorpack.utils.viz import dump_dataflow_images
from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average
import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, RandomZData, build_GAN_losses
......@@ -119,21 +118,38 @@ def sample(model_path):
session_init=get_model_loader(model_path),
model=Model(),
input_names=['z'],
output_names=['gen/gen'])
output_names=['gen/gen', 'z'])
pred = SimpleDatasetPredictor(pred, RandomZData((100, 100)))
for o in pred.get_result():
o = o[0] + 1
o, zs = o[0] + 1, o[1]
o = o * 128.0
o = o[:,:,:,::-1]
viz = next(build_patch_list(o, nr_row=10, nr_col=10))
cv2.imshow("", viz)
cv2.waitKey()
viz = next(build_patch_list(o, nr_row=10, nr_col=10, viz=True))
def interp(model_path):
func = OfflinePredictor(PredictConfig(
session_init=get_model_loader(model_path),
model=Model(),
input_names=['z'],
output_names=['gen/gen']))
dic = np.load('demo/CelebA-vec.npy').item()
assert np.all(
dic['w_smile'] - dic['w_neutral'] \
+ dic['m_neutral'] == dic['m_smile'])
imgs = []
for z in ['w_neutral', 'w_smile', 'm_neutral', 'm_smile']:
z = dic[z]
img = func([[z]])[0][0][:,:,::-1]
img = (img + 1) * 128
imgs.append(img)
viz = next(build_patch_list(imgs, nr_row=1, nr_col=4, viz=True))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
parser.add_argument('--sample', action='store_true', help='run sampling')
parser.add_argument('--interp', action='store_true', help='run interpolation')
parser.add_argument('--data', help='`image_align_celeba` directory of the celebA dataset')
global args
args = parser.parse_args()
......@@ -141,6 +157,8 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.sample:
sample(args.load)
elif args.interp:
interp(args.load)
else:
assert args.data
config = get_config()
......
......@@ -2,6 +2,12 @@
Reproduce DCGAN following the setup in [dcgan.torch](https://github.com/soumith/dcgan.torch).
More results to come.
Samples from CelebA dataset:
See the docstring in each executable script for usage.
![sample](demo/CelebA-samples.jpg)
Vector arithmetic: smiling woman - neutral woman + neutral man = smiling man
![vec](demo/CelebA-vec.jpg)
See the docstring in the script for usage.
......@@ -16,7 +16,7 @@ except ImportError:
pass
__all__ = ['pyplot2img', 'build_patch_list', 'pyplot_viz',
'dump_dataflow_images']
'dump_dataflow_images', 'interactive_imshow']
def pyplot2img(plt):
buf = io.BytesIO()
......@@ -45,21 +45,55 @@ def minnone(x, y):
elif y is None: y = x
return min(x, y)
def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
"""
:param lclick_cb: a callback(img, x, y) for left click
:param kwargs: can be {key_cb_a ... key_cb_z: callback(img)}
"""
name = 'random_window_name'
cv2.imshow(name, img)
def mouse_cb(event, x, y, *args):
if event == cv2.EVENT_LBUTTONUP and lclick_cb is not None:
lclick_cb(img, x, y)
elif event == cv2.EVENT_RBUTTONUP and rclick_cb is not None:
rclick_cb(img, x, y)
cv2.setMouseCallback(name, mouse_cb)
key = chr(cv2.waitKey(-1) & 0xff)
cb_name = 'key_cb_' + key
if cb_name in kwargs:
kwargs[cb_name](img)
elif key == 'q':
cv2.destroyWindow(name)
elif key == 'x':
sys.exit()
elif key == 's':
cv2.imwrite('out.png', img)
def build_patch_list(patch_list,
nr_row=None, nr_col=None, border=None,
max_width=1000, max_height=1000,
shuffle=False, bgcolor=255):
shuffle=False, bgcolor=255, viz=False, lclick_cb=None):
"""
This is a generator.
patch_list: bhw or bhwc
Generate patches.
:param patch_list: bhw or bhwc
:param border: defaults to 0.1 * max(image_width, image_height)
:param nr_row, nr_col: rows and cols of the grid
:parma max_width, max_height: if nr_row/col are not given, use this to infer the rows and cols
:param shuffle: shuffle the images
:param bgcolor: background color
:param viz: use interactive imshow to visualize the results
:param lclick_cb: only useful when viz=True. a callback(patch, idx)
"""
# setup parameters
patch_list = np.asarray(patch_list)
if patch_list.ndim == 3:
patch_list = patch_list[:,:,:,np.newaxis]
assert patch_list.ndim == 4 and patch_list.shape[3] in [1, 3], patch_list.shape
if shuffle:
np.random.shuffle(patch_list)
if lclick_cb is not None:
viz = True
ph, pw = patch_list.shape[1:3]
if border is None:
border = int(0.1 * max(ph, pw))
......@@ -87,12 +121,24 @@ def build_patch_list(patch_list,
nr_patch = nr_row * nr_col
start = 0
def lclick_callback(img, x, y):
if lclick_cb is None:
return
x = x // (pw + border)
y = y // (pw + border)
idx = start + y * nr_col + x
if idx < end:
lclick_cb(patch_list[idx], idx)
while True:
end = start + nr_patch
cur_list = patch_list[start:end]
if not len(cur_list):
return
draw_patch(cur_list)
if viz:
interactive_imshow(canvas, lclick_cb=lclick_callback)
yield canvas
start = end
......@@ -150,9 +196,7 @@ def dump_dataflow_images(df, index=0, batched=True,
if viz is not None and len(vizlist) >= vizsize:
patch = next(build_patch_list(
vizlist[:vizsize],
nr_row=viz[0], nr_col=viz[1]))
cv2.imshow("df-viz", patch)
cv2.waitKey()
nr_row=viz[0], nr_col=viz[1], viz=True))
vizlist = vizlist[vizsize:]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment