Commit 9c6e39c5 authored by Yuxin Wu

more beautiful shape logging

parent 92748c90
@@ -4,6 +4,7 @@
 import copy
 import re
+import collections
 from functools import wraps
 import six
 import tensorflow as tf
@@ -61,6 +62,40 @@ def disable_layer_logging():
     globals()['_LAYER_LOGGED'] = ContainEverything()
+
+
+class LayerShapeLogger():
+    """
+    A class that logs the shapes of layer inputs/outputs
+    during possibly-nested calls to them.
+    """
+    def __init__(self):
+        self.stack = collections.deque()
+        self.depth = 0
+
+    def _indent(self):
+        return " " * (self.depth * 2)
+
+    def push_inputs(self, name, message):
+        while len(self.stack):
+            item = self.stack.pop()
+            logger.info(self._indent() + "'{}' input: {}".format(item[0], item[1]))
+            self.depth += 1
+        self.stack.append((name, message))
+
+    def push_outputs(self, name, message):
+        if len(self.stack):
+            assert len(self.stack) == 1, self.stack
+            assert self.stack[-1][0] == name, self.stack
+            item = self.stack.pop()
+            logger.info(self._indent() + "'{}': {} --> {}".format(name, item[1], message))
+        else:
+            self.depth -= 1
+            logger.info(self._indent() + "'{}' output: {}".format(name, message))
+
+
+_SHAPE_LOGGER = LayerShapeLogger()
+
 def layer_register(
         log_shape=False,
         use_scope=True):
@@ -132,15 +167,13 @@ def layer_register(
             scope_name = re.sub('tower[0-9]+/', '', scope.name)
             do_log_shape = log_shape and scope_name not in _LAYER_LOGGED

             if do_log_shape:
-                logger.info("{} input: {}".format(scope.name, get_shape_str(inputs)))
+                _SHAPE_LOGGER.push_inputs(scope.name, get_shape_str(inputs))

             # run the actual function
             outputs = func(*args, **actual_args)
             if do_log_shape:
-                # log shape info and add activation
-                logger.info("{} output: {}".format(
-                    scope.name, get_shape_str(outputs)))
+                _SHAPE_LOGGER.push_outputs(scope.name, get_shape_str(outputs))
                 _LAYER_LOGGED.add(scope_name)
         else:
             # run the actual function
...
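For illustration, here is a minimal self-contained sketch (not part of this commit) of the indentation behavior the new LayerShapeLogger produces for nested layer calls. print() stands in for logger.info(), and the layer names and shape strings below are made up:

import collections

class LayerShapeLogger():
    """Same logic as the class added above, with print() replacing logger.info()."""
    def __init__(self):
        self.stack = collections.deque()
        self.depth = 0

    def _indent(self):
        return " " * (self.depth * 2)

    def push_inputs(self, name, message):
        # A new input arriving while another layer's input is still pending
        # means we entered a nested call: flush the pending input on its own
        # line and indent one level.
        while len(self.stack):
            item = self.stack.pop()
            print(self._indent() + "'{}' input: {}".format(item[0], item[1]))
            self.depth += 1
        self.stack.append((name, message))

    def push_outputs(self, name, message):
        if len(self.stack):
            # No nested call happened: merge input and output into one line.
            assert len(self.stack) == 1 and self.stack[-1][0] == name, self.stack
            item = self.stack.pop()
            print(self._indent() + "'{}': {} --> {}".format(name, item[1], message))
        else:
            # The input was already flushed by a nested call: close the layer
            # with its own "output" line, one level back out.
            self.depth -= 1
            print(self._indent() + "'{}' output: {}".format(name, message))

log = LayerShapeLogger()
log.push_inputs("block1", "[?, 224, 224, 3]")          # outer layer begins
log.push_inputs("block1/conv1", "[?, 224, 224, 3]")    # nested layer begins
log.push_outputs("block1/conv1", "[?, 112, 112, 64]")  # nested layer ends
log.push_outputs("block1", "[?, 112, 112, 64]")        # outer layer ends
# Output:
# 'block1' input: [?, 224, 224, 3]
#   'block1/conv1': [?, 224, 224, 3] --> [?, 112, 112, 64]
# 'block1' output: [?, 112, 112, 64]

The stack holds at most one pending (name, input-shape) pair: it is flushed as a standalone "input" line only when a nested layer starts before the enclosing one finishes; otherwise the input and output are merged into a single line.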
@@ -10,6 +10,7 @@ import tensorflow as tf
 from ..compat import is_tfv2
 from ..utils import logger
+from .model_utils import get_shape_str
 from .tower import get_current_tower_context

 __all__ = ['argscope', 'get_arg_scope', 'enable_argscope_for_module',
@@ -108,9 +109,10 @@ def enable_argscope_for_function(func, log_shape=True):
                 out_tensor_descr = out_tensor[0]
             else:
                 out_tensor_descr = out_tensor
-            logger.info('%20s: %20s -> %20s' %
-                        (name, in_tensor.shape.as_list(),
-                         out_tensor_descr.shape.as_list()))
+            logger.info("{:<12}: {} --> {}".format(
+                "'" + name + "'",
+                get_shape_str(in_tensor),
+                get_shape_str(out_tensor_descr)))
         return out_tensor

     wrapped_func.__argscope_enabled__ = True
...
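The new format string left-aligns the quoted layer name in a 12-character field and reuses get_shape_str for both tensors. A quick sketch (not part of this commit) of the resulting log line, with a made-up name and shape strings:

name = "conv1"  # hypothetical layer name
in_shape, out_shape = "[?, 224, 224, 3]", "[?, 112, 112, 64]"
print("{:<12}: {} --> {}".format("'" + name + "'", in_shape, out_shape))
# prints: 'conv1'     : [?, 224, 224, 3] --> [?, 112, 112, 64]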
@@ -79,9 +79,8 @@ def get_shape_str(tensors):
     if isinstance(tensors, (list, tuple)):
         for v in tensors:
             assert isinstance(v, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(v))
-        shape_str = ",".join(
-            map(lambda x: str(x.get_shape().as_list()), tensors))
+        shape_str = ", ".join(map(get_shape_str, tensors))
     else:
         assert isinstance(tensors, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(tensors))
-        shape_str = str(tensors.get_shape().as_list())
+        shape_str = str(tensors.get_shape().as_list()).replace("None", "?")
     return shape_str
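Two behavior changes here: unknown dimensions (None) are now rendered as '?', and lists of tensors recurse through the same helper with a space after each comma. A condensed sketch (not part of this commit) of the updated behavior, with the assertions omitted and hypothetical TF1-style placeholders used only to get tensors of partially-unknown shape:

import tensorflow as tf

def get_shape_str(tensors):  # condensed copy of the helper above, assertions omitted
    if isinstance(tensors, (list, tuple)):
        return ", ".join(map(get_shape_str, tensors))
    return str(tensors.get_shape().as_list()).replace("None", "?")

x = tf.placeholder(tf.float32, [None, 224, 224, 3])
y = tf.placeholder(tf.float32, [None, 1000])
print(get_shape_str(x))       # [?, 224, 224, 3] -- unknown dims print as '?'
print(get_shape_str([x, y]))  # [?, 224, 224, 3], [?, 1000]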