Commit 9c6e39c5 authored by Yuxin Wu

more beautiful shape logging

parent 92748c90
@@ -4,6 +4,7 @@
 import copy
 import re
+import collections
 from functools import wraps
 import six
 import tensorflow as tf
 
@@ -61,6 +62,40 @@ def disable_layer_logging():
     globals()['_LAYER_LOGGED'] = ContainEverything()
 
 
+class LayerShapeLogger():
+    """
+    A class that logs shapes of inputs/outputs of layers,
+    during the possibly-nested calls to them.
+    """
+
+    def __init__(self):
+        self.stack = collections.deque()
+        self.depth = 0
+
+    def _indent(self):
+        return " " * (self.depth * 2)
+
+    def push_inputs(self, name, message):
+        while len(self.stack):
+            item = self.stack.pop()
+            logger.info(self._indent() + "'{}' input: {}".format(item[0], item[1]))
+            self.depth += 1
+        self.stack.append((name, message))
+
+    def push_outputs(self, name, message):
+        if len(self.stack):
+            assert len(self.stack) == 1, self.stack
+            assert self.stack[-1][0] == name, self.stack
+            item = self.stack.pop()
+            logger.info(self._indent() + "'{}': {} --> {}".format(name, item[1], message))
+        else:
+            self.depth -= 1
+            logger.info(self._indent() + "'{}' output: {}".format(name, message))
+
+
+_SHAPE_LOGGER = LayerShapeLogger()
+
+
 def layer_register(
         log_shape=False,
         use_scope=True):
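
A quick illustration of the class above (not part of the commit): a pending stack entry means the current layer invoked another layer before returning, so its "input" line is flushed and the indent level grows; a leaf layer instead collapses to a single "input --> output" line. The self-contained sketch below copies that bookkeeping with print() in place of logger.info; the layer names and shapes are made up:

import collections

class ToyShapeLogger:
    """Stand-alone copy of LayerShapeLogger's bookkeeping, printing instead of logging."""
    def __init__(self):
        self.stack = collections.deque()
        self.depth = 0

    def _indent(self):
        return " " * (self.depth * 2)

    def push_inputs(self, name, message):
        # A pending entry means the previous layer called us before it returned:
        # flush its "input" line and indent one level deeper.
        while len(self.stack):
            item = self.stack.pop()
            print(self._indent() + "'{}' input: {}".format(item[0], item[1]))
            self.depth += 1
        self.stack.append((name, message))

    def push_outputs(self, name, message):
        if len(self.stack):
            # Leaf layer: nothing flushed the pending entry, so print one combined line.
            item = self.stack.pop()
            print(self._indent() + "'{}': {} --> {}".format(name, item[1], message))
        else:
            # Closing a layer whose "input" line was already flushed: dedent first.
            self.depth -= 1
            print(self._indent() + "'{}' output: {}".format(name, message))

log = ToyShapeLogger()
log.push_inputs("block1", "[?, 224, 224, 3]")    # outer layer starts
log.push_inputs("conv1", "[?, 224, 224, 3]")     # nested call flushes block1's input line
log.push_outputs("conv1", "[?, 112, 112, 64]")   # leaf: one "in --> out" line, indented
log.push_outputs("block1", "[?, 112, 112, 64]")  # outer layer closes, dedents

# Printed:
# 'block1' input: [?, 224, 224, 3]
#   'conv1': [?, 224, 224, 3] --> [?, 112, 112, 64]
# 'block1' output: [?, 112, 112, 64]
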
@@ -132,15 +167,13 @@ def layer_register(
                 scope_name = re.sub('tower[0-9]+/', '', scope.name)
                 do_log_shape = log_shape and scope_name not in _LAYER_LOGGED
 
                 if do_log_shape:
-                    logger.info("{} input: {}".format(scope.name, get_shape_str(inputs)))
+                    _SHAPE_LOGGER.push_inputs(scope.name, get_shape_str(inputs))
                 # run the actual function
                 outputs = func(*args, **actual_args)
 
                 if do_log_shape:
-                    # log shape info and add activation
-                    logger.info("{} output: {}".format(
-                        scope.name, get_shape_str(outputs)))
+                    _SHAPE_LOGGER.push_outputs(scope.name, get_shape_str(outputs))
                     _LAYER_LOGGED.add(scope_name)
         else:
             # run the actual function
@@ -10,6 +10,7 @@ import tensorflow as tf
 
 from ..compat import is_tfv2
 from ..utils import logger
+from .model_utils import get_shape_str
 from .tower import get_current_tower_context
 
 __all__ = ['argscope', 'get_arg_scope', 'enable_argscope_for_module',
@@ -108,9 +109,10 @@ def enable_argscope_for_function(func, log_shape=True):
                     out_tensor_descr = out_tensor[0]
                 else:
                     out_tensor_descr = out_tensor
-                logger.info('%20s: %20s -> %20s' %
-                            (name, in_tensor.shape.as_list(),
-                             out_tensor_descr.shape.as_list()))
+                logger.info("{:<12}: {} --> {}".format(
+                    "'" + name + "'",
+                    get_shape_str(in_tensor),
+                    get_shape_str(out_tensor_descr)))
         return out_tensor
 
     wrapped_func.__argscope_enabled__ = True
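
To see the new log line format in isolation (illustrative, not part of the commit; the layer name and shape strings are hypothetical):

name = "conv0"
in_shapes, out_shapes = "[?, 224, 224, 3]", "[?, 224, 224, 64]"
print("{:<12}: {} --> {}".format("'" + name + "'", in_shapes, out_shapes))
# 'conv0'     : [?, 224, 224, 3] --> [?, 224, 224, 64]
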
@@ -79,9 +79,8 @@ def get_shape_str(tensors):
     if isinstance(tensors, (list, tuple)):
         for v in tensors:
             assert isinstance(v, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(v))
-        shape_str = ",".join(
-            map(lambda x: str(x.get_shape().as_list()), tensors))
+        shape_str = ", ".join(map(get_shape_str, tensors))
     else:
         assert isinstance(tensors, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(tensors))
-        shape_str = str(tensors.get_shape().as_list())
+        shape_str = str(tensors.get_shape().as_list()).replace("None", "?")
     return shape_str
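
A hedged sketch of the rewritten get_shape_str in use, assuming TF1-style graph mode and that the module path matches the relative import in the argscope hunk above:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

from tensorpack.tfutils.model_utils import get_shape_str  # path assumed

x = tf.placeholder(tf.float32, [None, 224, 224, 3])
y = tf.placeholder(tf.float32, [None, 1000])

print(get_shape_str(x))       # [?, 224, 224, 3] -- unknown dims now print as "?"
print(get_shape_str([x, y]))  # [?, 224, 224, 3], [?, 1000] -- recurses and joins with ", "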