Commit 56178e92 authored by Yuxin Wu's avatar Yuxin Wu

clip in dumpimageparam

parent 84ba85fd
......@@ -5,6 +5,7 @@
import os
import scipy.misc
from scipy.misc import imsave
import numpy as np
from .base import Callback
from ..utils import logger
......@@ -12,11 +13,12 @@ from ..utils import logger
__all__ = ['DumpParamAsImage']
class DumpParamAsImage(Callback):
def __init__(self, var_name, prefix=None, map_func=None, scale=255):
def __init__(self, var_name, prefix=None, map_func=None, scale=255, clip=False):
"""
map_func: map the value of the variable to an image or list of images, default to identity
images should have shape [h, w] or [h, w, c].
scale: a multiplier on pixel values, applied after map_func. default to 255
clip: clip the result to [0, 255]
"""
self.var_name = var_name
self.func = map_func
......@@ -26,6 +28,7 @@ class DumpParamAsImage(Callback):
self.prefix = prefix
self.log_dir = logger.LOG_DIR
self.scale = scale
self.clip = clip
def _before_train(self):
self.var = self.graph.get_tensor_by_name(self.var_name)
......@@ -46,5 +49,8 @@ class DumpParamAsImage(Callback):
self.log_dir,
self.prefix + '-ep{:03d}{}.png'.format(
self.epoch_num, '-' + str(idx) if idx else ''))
imsave(fname, (im * self.scale).astype('uint8'))
res = im * self.scale
if self.clip:
res = np.clip(res, 0, 255)
imsave(fname, res.astype('uint8'))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment