Commit dcf55733 authored by Yuxin Wu's avatar Yuxin Wu

add image dump callback

parent 96255c9a
...@@ -32,6 +32,9 @@ class Callback(object): ...@@ -32,6 +32,9 @@ class Callback(object):
""" """
def after_train(self): def after_train(self):
self._after_train()
def _after_train(self):
""" """
Called after training Called after training
""" """
......
...@@ -56,6 +56,6 @@ class SummaryWriter(Callback): ...@@ -56,6 +56,6 @@ class SummaryWriter(Callback):
logger.info('{}: {:.4f}'.format(val.tag, val.simple_value)) logger.info('{}: {:.4f}'.format(val.tag, val.simple_value))
self.writer.add_summary(summary, get_global_step()) self.writer.add_summary(summary, get_global_step())
def after_train(self): def _after_train(self):
self.writer.close() self.writer.close()
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: dump.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from .base import Callback
import cv2
import os
from ..utils import logger
__all__ = ['DumpParamAsImage']
class DumpParamAsImage(Callback):
def __init__(self, var_name, prefix=None, map_func=None, scale=255):
"""
map_func: map the value of the variable to an image or list of images, default to identity
images should have shape [h, w] or [h, w, c].
scale: a scaling parameter on pixels
"""
self.var_name = var_name
self.func = map_func
if prefix is None:
self.prefix = self.var_name
else:
self.prefix = prefix
self.log_dir = logger.LOG_DIR
self.scale = scale
def _before_train(self):
self.var = self.graph.get_tensor_by_name(self.var_name)
self.epoch_num = 0
def trigger_epoch(self):
self.epoch_num += 1
val = self.sess.run(self.var)
if self.func is not None:
val = self.func(val)
if isinstance(val, list):
for idx, im in enumerate(val):
assert im.ndim in [2, 3], str(im.ndim)
fname = os.path.join(
self.log_dir,
self.prefix + '-ep{}-{}.png'.format(self.epoch_num, idx))
cv2.imwrite(fname, im * self.scale)
else:
im = val
assert im.ndim in [2, 3]
fname = os.path.join(
self.log_dir,
self.prefix + '-ep{}.png'.format(self.epoch_num))
cv2.imwrite(fname, im * self.scale)
...@@ -78,7 +78,7 @@ class TrainCallbacks(Callback): ...@@ -78,7 +78,7 @@ class TrainCallbacks(Callback):
cb.before_train() cb.before_train()
self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0] self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
def after_train(self): def _after_train(self):
for cb in self.cbs: for cb in self.cbs:
cb.after_train() cb.after_train()
...@@ -115,7 +115,7 @@ class TestCallbacks(Callback): ...@@ -115,7 +115,7 @@ class TestCallbacks(Callback):
for cb in self.cbs: for cb in self.cbs:
cb.before_train() cb.before_train()
def after_train(self): def _after_train(self):
for cb in self.cbs: for cb in self.cbs:
cb.after_train() cb.after_train()
...@@ -161,7 +161,7 @@ class Callbacks(Callback): ...@@ -161,7 +161,7 @@ class Callbacks(Callback):
self.train.before_train() self.train.before_train()
self.test.before_train() self.test.before_train()
def after_train(self): def _after_train(self):
self.train.after_train() self.train.after_train()
self.test.after_train() self.test.after_train()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment