Commit 56178e92 authored by Yuxin Wu's avatar Yuxin Wu

clip in dumpimageparam

parent 84ba85fd
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import os import os
import scipy.misc import scipy.misc
from scipy.misc import imsave from scipy.misc import imsave
import numpy as np
from .base import Callback from .base import Callback
from ..utils import logger from ..utils import logger
...@@ -12,11 +13,12 @@ from ..utils import logger ...@@ -12,11 +13,12 @@ from ..utils import logger
__all__ = ['DumpParamAsImage'] __all__ = ['DumpParamAsImage']
class DumpParamAsImage(Callback): class DumpParamAsImage(Callback):
def __init__(self, var_name, prefix=None, map_func=None, scale=255): def __init__(self, var_name, prefix=None, map_func=None, scale=255, clip=False):
""" """
map_func: map the value of the variable to an image or list of images, default to identity map_func: map the value of the variable to an image or list of images, default to identity
images should have shape [h, w] or [h, w, c]. images should have shape [h, w] or [h, w, c].
scale: a multiplier on pixel values, applied after map_func. default to 255 scale: a multiplier on pixel values, applied after map_func. default to 255
clip: clip the result to [0, 255]
""" """
self.var_name = var_name self.var_name = var_name
self.func = map_func self.func = map_func
...@@ -26,6 +28,7 @@ class DumpParamAsImage(Callback): ...@@ -26,6 +28,7 @@ class DumpParamAsImage(Callback):
self.prefix = prefix self.prefix = prefix
self.log_dir = logger.LOG_DIR self.log_dir = logger.LOG_DIR
self.scale = scale self.scale = scale
self.clip = clip
def _before_train(self): def _before_train(self):
self.var = self.graph.get_tensor_by_name(self.var_name) self.var = self.graph.get_tensor_by_name(self.var_name)
...@@ -46,5 +49,8 @@ class DumpParamAsImage(Callback): ...@@ -46,5 +49,8 @@ class DumpParamAsImage(Callback):
self.log_dir, self.log_dir,
self.prefix + '-ep{:03d}{}.png'.format( self.prefix + '-ep{:03d}{}.png'.format(
self.epoch_num, '-' + str(idx) if idx else '')) self.epoch_num, '-' + str(idx) if idx else ''))
imsave(fname, (im * self.scale).astype('uint8')) res = im * self.scale
if self.clip:
res = np.clip(res, 0, 255)
imsave(fname, res.astype('uint8'))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment