Commit 9a156417 authored by Yuxin Wu's avatar Yuxin Wu

EnableCallbackIf & DumpTensor

parent ea0342e5
......@@ -228,11 +228,15 @@ class ProxyCallback(Callback):
self.cb.before_train()
def _setup_graph(self):
self.cb.setup_graph(self.trainer)
with tf.name_scope(None):
self.cb.setup_graph(self.trainer)
def _trigger_epoch(self):
self.cb.trigger_epoch()
def _trigger(self):
self.cb.trigger()
def _trigger_step(self):
self.cb.trigger_step()
......
......@@ -106,6 +106,7 @@ class GraphProfiler(Callback):
The metadata files can be processed by
`tfprof command line utils
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/tfprof/g3doc/command_line.md>`_.
The event is viewable from tensorboard.
Note that the profiling is enabled for every step.
You probably want to schedule it less frequently by
......
......@@ -3,13 +3,15 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import tensorflow as tf
import numpy as np
from six.moves import zip
from .base import Callback
from ..utils import logger
from ..tfutils.common import get_op_tensor_name
from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
__all__ = ['SendStat', 'DumpParamAsImage', 'InjectShell']
__all__ = ['SendStat', 'DumpParamAsImage', 'InjectShell', 'DumpTensor']
class SendStat(Callback):
......@@ -121,6 +123,38 @@ class DumpParamAsImage(Callback):
cv2.imwrite(fname, res.astype('uint8'))
class DumpTensor(Callback):
"""
Dump some tensors to a file.
Every step this callback fetches tensors and write them to a npz file under ``logger.LOG_DIR``.
The dump can be loaded by ``dict(np.load(filename).items())``.
"""
# TODO run as trigger
def __init__(self, names):
"""
Args:
names (list[str]): names of tensors
"""
self._names = names
self._dir = logger.LOG_DIR
def _setup_graph(self):
tensors = get_tensors_by_names(self._names)
self._fetch = tf.train.SessionRunArgs(fetches=tensors)
def _before_run(self, _):
return self._fetch
def _after_run(self, _, rv):
results = rv.results
dic = {}
for name, val in zip(self._names, results):
dic[name] = val
fname = os.path.join(
self._dir, 'DumpTensor-{}.npz'.format(self.global_step))
np.savez(fname, **dic)
try:
import cv2
except ImportError:
......
......@@ -5,7 +5,7 @@
from .base import ProxyCallback, Callback
__all__ = ['PeriodicTrigger', 'PeriodicRunHooks']
__all__ = ['PeriodicTrigger', 'PeriodicRunHooks', 'EnableCallbackIf']
class PeriodicTrigger(ProxyCallback):
......@@ -71,3 +71,51 @@ class PeriodicRunHooks(ProxyCallback):
def __str__(self):
return "PeriodicRunHooks-" + str(self.cb)
class EnableCallbackIf(ProxyCallback):
"""
Enable ``{before,after}_epoch``, ``{before,after}_run``, ``trigger*``
methods of a callback, only when some condition satisfies.
The other methods will be called the same.
Note:
If you need to use ``{before,after}_run``, make sure
that ``pred`` will eval to the same results in both methods every step.
"""
def __init__(self, callback, pred):
"""
Args:
callback (Callback):
pred (self -> bool): a callable predicate
"""
self._pred = pred
super(EnableCallbackIf, self).__init__(callback)
def _before_run(self, ctx):
if self._pred(self):
super(EnableCallbackIf, self)._before_run(ctx)
def _after_run(self, ctx, rv):
if self._pred(self):
super(EnableCallbackIf, self)._after_run(ctx, rv)
def _before_epoch(self):
if self._pred(self):
super(EnableCallbackIf, self)._before_epoch()
def _after_epoch(self):
if self._pred(self):
super(EnableCallbackIf, self)._after_epoch()
def _trigger(self):
if self._pred(self):
super(EnableCallbackIf, self)._trigger()
def _trigger_epoch(self):
if self._pred(self):
super(EnableCallbackIf, self)._trigger_epoch()
def _trigger_step(self):
if self._pred(self):
super(EnableCallbackIf, self)._trigger_step()
......@@ -105,7 +105,10 @@ def print_stat(x, message=None):
"""
if message is None:
message = x.op.name
return tf.Print(x, [tf.shape(x), tf.reduce_mean(x), rms(x), x], summarize=20,
lst = [tf.shape(x), tf.reduce_mean(x)]
if x.dtype.is_floating:
lst.append(rms(x))
return tf.Print(x, lst + [x], summarize=20,
message=message, name='print_' + x.op.name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment