Commit 9a156417 authored by Yuxin Wu's avatar Yuxin Wu

EnableCallbackIf & DumpTensor

parent ea0342e5
...@@ -228,11 +228,15 @@ class ProxyCallback(Callback): ...@@ -228,11 +228,15 @@ class ProxyCallback(Callback):
self.cb.before_train() self.cb.before_train()
def _setup_graph(self): def _setup_graph(self):
self.cb.setup_graph(self.trainer) with tf.name_scope(None):
self.cb.setup_graph(self.trainer)
def _trigger_epoch(self): def _trigger_epoch(self):
self.cb.trigger_epoch() self.cb.trigger_epoch()
def _trigger(self):
self.cb.trigger()
def _trigger_step(self): def _trigger_step(self):
self.cb.trigger_step() self.cb.trigger_step()
......
...@@ -106,6 +106,7 @@ class GraphProfiler(Callback): ...@@ -106,6 +106,7 @@ class GraphProfiler(Callback):
The metadata files can be processed by The metadata files can be processed by
`tfprof command line utils `tfprof command line utils
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/tfprof/g3doc/command_line.md>`_. <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/tfprof/g3doc/command_line.md>`_.
The event is viewable from tensorboard.
Note that the profiling is enabled for every step. Note that the profiling is enabled for every step.
You probably want to schedule it less frequently by You probably want to schedule it less frequently by
......
...@@ -3,13 +3,15 @@ ...@@ -3,13 +3,15 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os import os
import tensorflow as tf
import numpy as np import numpy as np
from six.moves import zip
from .base import Callback from .base import Callback
from ..utils import logger from ..utils import logger
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
__all__ = ['SendStat', 'DumpParamAsImage', 'InjectShell'] __all__ = ['SendStat', 'DumpParamAsImage', 'InjectShell', 'DumpTensor']
class SendStat(Callback): class SendStat(Callback):
...@@ -121,6 +123,38 @@ class DumpParamAsImage(Callback): ...@@ -121,6 +123,38 @@ class DumpParamAsImage(Callback):
cv2.imwrite(fname, res.astype('uint8')) cv2.imwrite(fname, res.astype('uint8'))
class DumpTensor(Callback):
"""
Dump some tensors to a file.
Every step this callback fetches tensors and write them to a npz file under ``logger.LOG_DIR``.
The dump can be loaded by ``dict(np.load(filename).items())``.
"""
# TODO run as trigger
def __init__(self, names):
"""
Args:
names (list[str]): names of tensors
"""
self._names = names
self._dir = logger.LOG_DIR
def _setup_graph(self):
tensors = get_tensors_by_names(self._names)
self._fetch = tf.train.SessionRunArgs(fetches=tensors)
def _before_run(self, _):
return self._fetch
def _after_run(self, _, rv):
results = rv.results
dic = {}
for name, val in zip(self._names, results):
dic[name] = val
fname = os.path.join(
self._dir, 'DumpTensor-{}.npz'.format(self.global_step))
np.savez(fname, **dic)
try: try:
import cv2 import cv2
except ImportError: except ImportError:
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
from .base import ProxyCallback, Callback from .base import ProxyCallback, Callback
__all__ = ['PeriodicTrigger', 'PeriodicRunHooks'] __all__ = ['PeriodicTrigger', 'PeriodicRunHooks', 'EnableCallbackIf']
class PeriodicTrigger(ProxyCallback): class PeriodicTrigger(ProxyCallback):
...@@ -71,3 +71,51 @@ class PeriodicRunHooks(ProxyCallback): ...@@ -71,3 +71,51 @@ class PeriodicRunHooks(ProxyCallback):
def __str__(self): def __str__(self):
return "PeriodicRunHooks-" + str(self.cb) return "PeriodicRunHooks-" + str(self.cb)
class EnableCallbackIf(ProxyCallback):
"""
Enable ``{before,after}_epoch``, ``{before,after}_run``, ``trigger*``
methods of a callback, only when some condition satisfies.
The other methods will be called the same.
Note:
If you need to use ``{before,after}_run``, make sure
that ``pred`` will eval to the same results in both methods every step.
"""
def __init__(self, callback, pred):
"""
Args:
callback (Callback):
pred (self -> bool): a callable predicate
"""
self._pred = pred
super(EnableCallbackIf, self).__init__(callback)
def _before_run(self, ctx):
if self._pred(self):
super(EnableCallbackIf, self)._before_run(ctx)
def _after_run(self, ctx, rv):
if self._pred(self):
super(EnableCallbackIf, self)._after_run(ctx, rv)
def _before_epoch(self):
if self._pred(self):
super(EnableCallbackIf, self)._before_epoch()
def _after_epoch(self):
if self._pred(self):
super(EnableCallbackIf, self)._after_epoch()
def _trigger(self):
if self._pred(self):
super(EnableCallbackIf, self)._trigger()
def _trigger_epoch(self):
if self._pred(self):
super(EnableCallbackIf, self)._trigger_epoch()
def _trigger_step(self):
if self._pred(self):
super(EnableCallbackIf, self)._trigger_step()
...@@ -105,7 +105,10 @@ def print_stat(x, message=None): ...@@ -105,7 +105,10 @@ def print_stat(x, message=None):
""" """
if message is None: if message is None:
message = x.op.name message = x.op.name
return tf.Print(x, [tf.shape(x), tf.reduce_mean(x), rms(x), x], summarize=20, lst = [tf.shape(x), tf.reduce_mean(x)]
if x.dtype.is_floating:
lst.append(rms(x))
return tf.Print(x, lst + [x], summarize=20,
message=message, name='print_' + x.op.name) message=message, name='print_' + x.op.name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment