Commit 3465e1a5 authored by Yuxin Wu's avatar Yuxin Wu

ProcessTensors and DumpTensors

parent 1d74ac21
......@@ -368,7 +368,8 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'GaussianDeform',
'dump_chkpt_vars',
'VisualQA',
'huber_loss'
'huber_loss',
'DumpTensor'
]:
return True
if name in ['get_data', 'size', 'reset_state']:
......
......@@ -16,7 +16,6 @@ about 0.6% validation error after 30 epochs.
from tensorpack import *
from tensorpack.tfutils import summary
from tensorpack.dataflow import dataset
import tensorpack.tfutils.symbolic_functions as symbf
IMAGE_SIZE = 28
......@@ -63,15 +62,15 @@ class Model(ModelDesc):
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') # the average cross-entropy loss
# compute the "incorrect vector", for the callback ClassificationError to use at validation time
wrong = symbf.prediction_incorrect(logits, label, name='incorrect')
accuracy = symbf.accuracy(logits, label, name='accuracy')
# compute the "correct vector", for the callback ClassificationError to use at validation time
correct = tf.cast(tf.nn.in_top_k(logits, label, 1), tf.float32, name='correct')
accuracy = tf.reduce_mean(correct, name='accuracy')
# This will monitor training error (in a moving_average fashion):
# 1. write the value to tensosrboard
# 2. write the value to stat.json
# 3. print the value after each epoch
train_error = tf.reduce_mean(wrong, name='train_error')
train_error = tf.reduce_mean(1 - correct, name='train_error')
summary.add_moving_summary(train_error, accuracy)
# Use a regex to find parameters to apply weight decay.
......@@ -118,9 +117,9 @@ def get_config():
MaxSaver('validation_accuracy'), # save the model with highest accuracy (prefix 'validation_')
InferenceRunner( # run inference(for validation) after every epoch
dataset_test, # the DataFlow instance used for validation
# Calculate both the cost and the error for this DataFlow
[ScalarStats('cross_entropy_loss'), ScalarStats('accuracy'),
ClassificationError('incorrect')]),
# Calculate both the cost and the accuracy for this DataFlow
[ScalarStats('cross_entropy_loss'),
ClassificationError('correct', 'validation_accuracy')]),
],
steps_per_epoch=steps_per_epoch,
max_epoch=100,
......
......@@ -6,11 +6,15 @@
""" Graph related callbacks"""
import tensorflow as tf
import os
import numpy as np
from ..utils import logger
from .base import Callback
from ..tfutils.common import get_tensors_by_names
from six.moves import zip
__all__ = ['RunOp', 'RunUpdateOps']
__all__ = ['RunOp', 'RunUpdateOps', 'ProcessTensors', 'DumpTensors', 'DumpTensor']
class RunOp(Callback):
......@@ -87,3 +91,65 @@ class RunUpdateOps(RunOp):
super(RunUpdateOps, self).__init__(
f, run_before=False, run_as_trigger=False, run_step=True)
class ProcessTensors(Callback):
"""
Fetch extra tensors **along with** each training step,
and call some function over the values.
You can use it to print tensors, save tensors to file, etc.
Examples:
.. code-block:: python
ProcessTensors(['mycost1', 'mycost2'], lambda c1, c2: print(c1, c2, c1 + c2))
"""
def __init__(self, names, fn):
"""
Args:
names (list[str]): names of tensors
fn: a function taking all requested tensors as input
"""
assert isinstance(names, (list, tuple)), names
self._names = names
self._fn = fn
def _setup_graph(self):
tensors = get_tensors_by_names(self._names)
self._fetch = tf.train.SessionRunArgs(fetches=tensors)
def _before_run(self, _):
return self._fetch
def _after_run(self, _, rv):
results = rv.results
self._fn(*results)
class DumpTensors(ProcessTensors):
"""
Dump some tensors to a file.
Every step this callback fetches tensors and write them to a npz file under ``logger.LOG_DIR``.
The dump can be loaded by ``dict(np.load(filename).items())``.
"""
def __init__(self, names):
"""
Args:
names (list[str]): names of tensors
"""
assert isinstance(names, (list, tuple)), names
self._names = names
dir = logger.LOG_DIR
def fn(*args):
dic = {}
for name, val in zip(self._names, args):
dic[name] = val
fname = os.path.join(
dir, 'DumpTensor-{}.npz'.format(self.global_step))
np.savez(fname, **dic)
super(DumpTensors, self).__init__(names, fn)
DumpTensor = DumpTensors
......@@ -143,6 +143,9 @@ class ClassificationError(Inferencer):
taking account of the fact that batches might not have the same size in
testing (because the size of test set might not be a multiple of batch size).
Therefore the result can be different from averaging the error rate of each batch.
You can also use the "correct prediction" tensor, so this inferencer will
give you "classification accuracy" instead of error.
"""
def __init__(self, wrong_tensor_name='incorrect_vector', summary_name='validation_error'):
......
......@@ -3,15 +3,13 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import tensorflow as tf
import numpy as np
from six.moves import zip
from .base import Callback
from ..utils import logger
from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
from ..tfutils.common import get_op_tensor_name
__all__ = ['SendStat', 'DumpParamAsImage', 'InjectShell', 'DumpTensor']
__all__ = ['SendStat', 'DumpParamAsImage', 'InjectShell']
class SendStat(Callback):
......@@ -123,39 +121,6 @@ class DumpParamAsImage(Callback):
cv2.imwrite(fname, res.astype('uint8'))
class DumpTensor(Callback):
"""
Dump some tensors to a file.
Every step this callback fetches tensors and write them to a npz file under ``logger.LOG_DIR``.
The dump can be loaded by ``dict(np.load(filename).items())``.
"""
# TODO run as trigger
def __init__(self, names):
"""
Args:
names (list[str]): names of tensors
"""
assert isinstance(names, (list, tuple)), names
self._names = names
self._dir = logger.LOG_DIR
def _setup_graph(self):
tensors = get_tensors_by_names(self._names)
self._fetch = tf.train.SessionRunArgs(fetches=tensors)
def _before_run(self, _):
return self._fetch
def _after_run(self, _, rv):
results = rv.results
dic = {}
for name, val in zip(self._names, results):
dic[name] = val
fname = os.path.join(
self._dir, 'DumpTensor-{}.npz'.format(self.global_step))
np.savez(fname, **dic)
try:
import cv2
except ImportError:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment