Commit 00fdb263 authored by Yuxin Wu

refactor in viz. stack_patches now handles nonuniform patches as well.

parent d6a89f0b
......@@ -125,7 +125,7 @@ def sample(model_path):
pred = SimpleDatasetPredictor(pred, ds)
for o in pred.get_result():
o = o[0] * 255.0
viz = next(build_patch_list(o, nr_row=10, nr_col=10))
viz = stack_patches(o, nr_row=10, nr_col=10)
viz = cv2.resize(viz, (800, 800))
interactive_imshow(viz)
......
......@@ -121,7 +121,7 @@ def sample(model_path):
o, zs = o[0] + 1, o[1]
o = o * 128.0
o = o[:, :, :, ::-1]
viz = next(build_patch_list(o, nr_row=10, nr_col=10, viz=True))
viz = stack_patches(o, nr_row=10, nr_col=10, viz=True)
if __name__ == '__main__':
......
......@@ -196,7 +196,7 @@ def sample(datadir, model_path):
pred = SimpleDatasetPredictor(pred, ds)
for o in pred.get_result():
o = o[0][:, :, :, ::-1]
next(build_patch_list(o, nr_row=3, nr_col=2, viz=True))
stack_patches(o, nr_row=3, nr_col=2, viz=True)
if __name__ == '__main__':
......
......@@ -192,24 +192,24 @@ def sample(model_path):
z_noise = np.random.uniform(-1, 1, (100, NOISE_DIM))
zc = np.concatenate((z_cat, z_uni * 0, z_uni * 0), axis=1)
o = pred(zc, z_noise)[0]
viz1 = next(build_patch_list(o, nr_row=10, nr_col=10))
viz1 = stack_patches(o, nr_row=10, nr_col=10)
viz1 = cv2.resize(viz1, (IMG_SIZE, IMG_SIZE))
# show effect of first continuous variable with fixed noise
zc = np.concatenate((z_cat, z_uni, z_uni * 0), axis=1)
o = pred(zc, z_noise * 0)[0]
viz2 = next(build_patch_list(o, nr_row=10, nr_col=10))
viz2 = stack_patches(o, nr_row=10, nr_col=10)
viz2 = cv2.resize(viz2, (IMG_SIZE, IMG_SIZE))
# show effect of second continuous variable with fixed noise
zc = np.concatenate((z_cat, z_uni * 0, z_uni), axis=1)
o = pred(zc, z_noise * 0)[0]
viz3 = next(build_patch_list(o, nr_row=10, nr_col=10))
viz3 = stack_patches(o, nr_row=10, nr_col=10)
viz3 = cv2.resize(viz3, (IMG_SIZE, IMG_SIZE))
viz = next(build_patch_list(
viz = stack_patches(
[viz1, viz2, viz3],
nr_row=1, nr_col=3, border=5, bgcolor=(255, 0, 0)))
nr_row=1, nr_col=3, border=5, bgcolor=(255, 0, 0))
interactive_imshow(viz)
......
......@@ -124,7 +124,7 @@ class AccumGradOptimizer(ProxyOptimizer):
An optimizer which accumulates gradients across :math:`k` :meth:`minimize` calls,
and applies them together in every :math:`k`th :meth:`minimize` call.
This is equivalent to using a :math:`k` times larger batch size plus a
:math:`k` times larger learning rate, but use much less memory.
:math:`k` times larger learning rate, but uses much less memory.
"""
def __init__(self, opt, niter):
......
......@@ -17,8 +17,9 @@ except ImportError:
pass
__all__ = ['pyplot2img', 'interactive_imshow', 'build_patch_list',
'pyplot_viz', 'dump_dataflow_images', 'intensity_to_rgb', 'stack_images']
__all__ = ['pyplot2img', 'pyplot_viz', 'interactive_imshow',
'stack_patches', 'gen_stack_patches',
'dump_dataflow_images', 'intensity_to_rgb']
def pyplot2img(plt):
......@@ -51,14 +52,6 @@ def pyplot_viz(img, shape=None):
return ret
def minnone(x, y):
    """
    Return the smaller of ``x`` and ``y``, treating ``None`` as "no value".

    Args:
        x, y: two comparable values, either of which may be ``None``.

    Returns:
        The other argument when exactly one is ``None``, ``None`` when both
        are ``None``, otherwise ``min(x, y)``.
    """
    if x is None:
        # Covers the both-None case too: ``min(None, None)`` raises
        # TypeError on Python 3, so never reach ``min`` with a None.
        return y
    if y is None:
        return x
    return min(x, y)
def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
"""
Args:
......@@ -94,96 +87,197 @@ def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
cv2.imwrite('out.png', img)
def build_patch_list(patch_list,
nr_row=None, nr_col=None, border=None,
max_width=1000, max_height=1000,
shuffle=False, bgcolor=255,
viz=False, lclick_cb=None):
def _preproecss_patch_list(plist):
plist = np.asarray(plist)
if plist.ndim == 3:
plist = plist[:, :, :, np.newaxis]
assert plist.ndim == 4 and plist.shape[3] in [1, 3], plist.shape
return plist
def _pad_patch_list(plist, bgcolor):
if isinstance(bgcolor, int):
bgcolor = (bgcolor, bgcolor, bgcolor)
def _pad_channel(plist):
ret = []
for p in plist:
if len(p.shape) == 2:
p = p[:, :, np.newaxis]
if p.shape[2] == 1:
p = np.repeat(p, 3, 2)
ret.append(p)
return ret
plist = _pad_channel(plist)
shapes = [x.shape for x in plist]
ph = max([s[0] for s in shapes])
pw = max([s[1] for s in shapes])
ret = np.zeros((len(plist), ph, pw, 3), dtype=plist[0].dtype)
ret[:, :, :] = bgcolor
for idx, p in enumerate(plist):
s = p.shape
sh = (ph - s[0]) / 2
sw = (pw - s[1]) / 2
ret[idx, sh:sh + s[0], sw:sw + s[1], :] = p
return ret
class Canvas(object):
    """
    A reusable uint8 image buffer holding an ``nr_row`` x ``nr_col`` grid of
    equally sized patches separated by ``border`` pixels of ``bgcolor``.
    """

    def __init__(self, ph, pw,
                 nr_row, nr_col,
                 channel, border, bgcolor):
        """
        Args:
            ph, pw (int): height and width of each patch.
            nr_row, nr_col (int): grid size.
            channel (int): number of channels of the patches.
            border (int or None): gap between patches. ``None`` means
                ``int(0.1 * min(ph, pw))``.
            bgcolor (int or tuple): background color. A non-int background
                forces a 3-channel canvas.
        """
        self.ph = ph
        self.pw = pw
        self.nr_row = nr_row
        self.nr_col = nr_col

        if border is None:
            border = int(0.1 * min(ph, pw))
        self.border = border

        bgchannel = 1 if isinstance(bgcolor, int) else 3
        self.bgcolor = bgcolor
        self.channel = max(channel, bgchannel)

        # There is no border after the last row / column, hence "- border".
        self.canvas = np.zeros((nr_row * (ph + border) - border,
                                nr_col * (pw + border) - border,
                                self.channel), dtype='uint8')

    def draw_patches(self, plist):
        """
        Clear the canvas to the background color and draw ``plist`` onto it
        in row-major order.

        Args:
            plist (np.ndarray): NHWC patches, at most ``nr_row * nr_col`` of
                them. Cells without a patch keep the background color, which
                lets a generator draw a shorter final batch.
        """
        assert len(plist) <= self.nr_row * self.nr_col, \
            "{}*{} < {}".format(self.nr_row, self.nr_col, len(plist))
        if self.channel == 3 and plist.shape[3] == 1:
            plist = np.repeat(plist, 3, axis=3)
        if self.channel == 1:
            self.canvas.fill(self.bgcolor)
        else:
            self.canvas[:, :, :] = self.bgcolor
        for idx, patch in enumerate(plist):
            cur_row, cur_col = divmod(idx, self.nr_col)
            r0 = cur_row * (self.ph + self.border)
            c0 = cur_col * (self.pw + self.border)
            self.canvas[r0:r0 + self.ph, c0:c0 + self.pw] = patch

    def get_patchid_from_coord(self, x, y):
        """
        Map a pixel coordinate on the canvas to a row-major patch index.

        Args:
            x (int): horizontal pixel coordinate (column).
            y (int): vertical pixel coordinate (row).

        Returns:
            int: ``row * nr_col + col`` of the grid cell containing the pixel.
        """
        col = x // (self.pw + self.border)
        # Rows are spaced by the patch *height*, not the width.
        row = y // (self.ph + self.border)
        return row * self.nr_col + col
def stack_patches(
        patch_list, nr_row, nr_col, border=None,
        pad=False, bgcolor=255, viz=False, lclick_cb=None):
    """
    Stack patches into a grid, to produce visualizations like the following:

    .. image:: https://github.com/ppwwyyxx/tensorpack/raw/master/examples/GAN/demo/CelebA-samples.jpg

    Args:
        patch_list(list[ndarray] or ndarray): NHW or NHWC images in [0,255].
        nr_row(int), nr_col(int): rows and cols of the grid.
            ``nr_col * nr_row`` should equal ``len(patch_list)``.
        border(int): border length between images.
            Defaults to ``0.1 * min(image_w, image_h)``.
        pad (boolean): when `patch_list` is a list, pad all patches to the maximum height and width.
            This option allows stacking patches of different shapes together.
        bgcolor(int or 3-tuple): background color in [0, 255]. Either an int
            or a BGR tuple.
        viz(bool): whether to use :func:`interactive_imshow` to visualize the results.
        lclick_cb: A callback function ``f(patch, patch index in patch_list)``
            to get called when a patch gets clicked in imshow.
            Passing a callback implies ``viz=True``.

    Returns:
        np.ndarray: the stacked image.
    """
    if pad:
        # The padding helper needs the background color to fill the margins.
        patch_list = _pad_patch_list(patch_list, bgcolor)
    patch_list = _preproecss_patch_list(patch_list)

    if lclick_cb is not None:
        viz = True
    ph, pw = patch_list.shape[1:3]

    canvas = Canvas(ph, pw, nr_row, nr_col,
                    patch_list.shape[-1], border, bgcolor)

    if lclick_cb is not None:
        def lclick_callback(img, x, y):
            idx = canvas.get_patchid_from_coord(x, y)
            # Ignore clicks on grid cells that hold no patch.
            if idx < len(patch_list):
                lclick_cb(patch_list[idx], idx)
    else:
        lclick_callback = None

    canvas.draw_patches(patch_list)
    if viz:
        interactive_imshow(canvas.canvas, lclick_cb=lclick_callback)
    # Return the image buffer itself, not the Canvas wrapper: callers pass
    # the result straight to cv2 functions.
    return canvas.canvas
def gen_stack_patches(patch_list,
                      nr_row=None, nr_col=None, border=None,
                      max_width=1000, max_height=1000,
                      bgcolor=255, viz=False, lclick_cb=None):
    """
    Similar to :func:`stack_patches` but with a generator interface.
    It takes a much-longer list and yields stacked results one by one.
    For example, if ``patch_list`` contains 1000 images and ``nr_row==nr_col==10``,
    this generator yields 10 stacked images.

    Args:
        nr_row(int), nr_col(int): rows and cols of each result.
        max_width(int), max_height(int): Maximum allowed size of the
            stacked image. If ``nr_row/nr_col`` are None, this number
            will be used to infer the rows and cols. Otherwise the option is
            ignored.
        patch_list, border, viz, lclick_cb: same as in :func:`stack_patches`.

    Yields:
        np.ndarray: the stacked image. The same buffer is redrawn and yielded
        again on every iteration, so copy it if it must outlive the next step.
    """
    # setup parameters
    patch_list = _preproecss_patch_list(patch_list)
    if lclick_cb is not None:
        viz = True
    ph, pw = patch_list.shape[1:3]

    if border is None:
        border = int(0.1 * min(ph, pw))
    # Always keep at least one row/column, even when a single patch is
    # larger than the requested maximum size.
    if nr_row is None:
        nr_row = max(1, int(max_height / (ph + border)))
    if nr_col is None:
        nr_col = max(1, int(max_width / (pw + border)))
    canvas = Canvas(ph, pw, nr_row, nr_col, patch_list.shape[-1], border, bgcolor)

    nr_patch = nr_row * nr_col
    start = 0

    if lclick_cb is not None:
        def lclick_callback(img, x, y):
            # ``start`` / ``end`` are read from the enclosing scope at click
            # time, i.e. they describe the page currently being shown.
            idx = canvas.get_patchid_from_coord(x, y) + start
            # The last page may be partially filled: guard against both the
            # page end and the end of the whole list.
            if idx < min(end, len(patch_list)):
                lclick_cb(patch_list[idx], idx)
    else:
        lclick_callback = None

    while True:
        end = start + nr_patch
        cur_list = patch_list[start:end]
        if not len(cur_list):
            return
        canvas.draw_patches(cur_list)
        if viz:
            interactive_imshow(canvas.canvas, lclick_cb=lclick_callback)
        # Yield the image buffer, not the Canvas wrapper.
        yield canvas.canvas
        start = end
......@@ -205,7 +299,7 @@ def dump_dataflow_images(df, index=0, batched=True,
scale (float): scale the value, usually either 1 or 255.
resize (tuple or None): tuple of (h, w) to resize the images to.
viz (tuple or None): tuple of (h, w) determining the grid size to use
with :func:`build_patch_list` for visualization. No visualization will happen by
with :func:`gen_stack_patches` for visualization. No visualization will happen by
default.
flipRGB (bool): apply a RGB<->BGR conversion or not.
"""
......@@ -242,9 +336,9 @@ def dump_dataflow_images(df, index=0, batched=True,
if viz is not None:
vizlist.append(img)
if viz is not None and len(vizlist) >= vizsize:
next(build_patch_list(
stack_patches(
vizlist[:vizsize],
nr_row=viz[0], nr_col=viz[1], viz=True))
nr_row=viz[0], nr_col=viz[1], viz=True)
vizlist = vizlist[vizsize:]
......@@ -276,48 +370,18 @@ def intensity_to_rgb(intensity, cmap='cubehelix', normalize=False):
return intensity.astype('float32') * 255.0
def stack_images(imgs, vertical=False):
    """Stack images with different shapes and different number of channels.

    Images are placed next to each other, aligned to the top-left corner.
    Uncovered area is left black. If every image is 2-D the result is 2-D;
    otherwise the result has 3 channels and grayscale images are broadcast
    to all of them.

    Args:
        imgs (list[np.array]): images shaped HW, HW1 or HW3, in [0, 255].
        vertical (bool, optional): stack images vertically

    Returns:
        np.array: stacked images, dtype uint8.
    """
    rows = [x.shape[0] for x in imgs]
    cols = [x.shape[1] for x in imgs]
    # Decide the channel layout from all images rather than only the first
    # one, so that grayscale and color inputs can be mixed in any order.
    color = any(len(x.shape) != 2 for x in imgs)

    if vertical:
        shape = (sum(rows), max(cols))
    else:
        shape = (max(rows), sum(cols))
    if color:
        shape += (3,)
    out = np.zeros(shape, dtype='uint8')

    offset = 0
    for img, h, w in zip(imgs, rows, cols):
        assert img.max() > 1, "expect images within range [0, 255]"
        if color and len(img.shape) == 2:
            # add a channel axis so the patch broadcasts over 3 channels
            img = img[:, :, np.newaxis]
        if vertical:
            out[offset:offset + h, :w] = img
            offset += h
        else:
            out[:h, offset:offset + w] = img
            offset += w
    return out
if __name__ == '__main__':
imglist = []
for i in range(100):
fname = "{:03d}.png".format(i)
imglist.append(cv2.imread(fname))
for idx, patch in enumerate(build_patch_list(
imglist, max_width=500, max_height=200)):
of = "patch{:02d}.png".format(idx)
cv2.imwrite(of, patch)
if False:
imglist = []
for i in range(100):
fname = "{:03d}.png".format(i)
imglist.append(cv2.imread(fname))
for idx, patch in enumerate(gen_stack_patches(
imglist, max_width=500, max_height=200)):
of = "patch{:02d}.png".format(idx)
cv2.imwrite(of, patch)
else:
imglist = []
img = cv2.imread('out.png')
img2 = cv2.resize(img, (300, 300))
viz = stack_patches([img, img2], 1, 2, pad=True, viz=True)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment