Commit 77dc71e3 authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

allow verbose shape information output for tf.layers (#792)

* Improving the information when using `tf.layers` similar to the
layers provided by tensorpack (like Conv2D, etc). I don't know if
this is the perfect solution.

* try to identity tower0
parent fa34d239
...@@ -6,6 +6,8 @@ from collections import defaultdict ...@@ -6,6 +6,8 @@ from collections import defaultdict
import copy import copy
from functools import wraps from functools import wraps
from inspect import isfunction, getmembers from inspect import isfunction, getmembers
from ..utils import logger
import tensorflow as tf
__all__ = ['argscope', 'get_arg_scope', 'enable_argscope_for_module'] __all__ = ['argscope', 'get_arg_scope', 'enable_argscope_for_module']
...@@ -64,7 +66,7 @@ def get_arg_scope(): ...@@ -64,7 +66,7 @@ def get_arg_scope():
return defaultdict(dict) return defaultdict(dict)
def argscope_mapper(func): def argscope_mapper(func, log_shape=True):
"""Decorator for function to support argscope """Decorator for function to support argscope
""" """
@wraps(func) @wraps(func)
...@@ -72,13 +74,26 @@ def argscope_mapper(func): ...@@ -72,13 +74,26 @@ def argscope_mapper(func):
actual_args = copy.copy(get_arg_scope()[func.__name__]) actual_args = copy.copy(get_arg_scope()[func.__name__])
actual_args.update(kwargs) actual_args.update(kwargs)
out_tensor = func(*args, **actual_args) out_tensor = func(*args, **actual_args)
scope_name = tf.get_variable_scope().name
is_tower_scope = 'tower' in scope_name
in_tensor = args[0]
name = '<unkown>' if 'name' not in kwargs else kwargs['name']
if log_shape:
if is_tower_scope:
if 'tower0' in scope_name:
logger.info('%20s: %20s -> %20s' % (name, in_tensor.shape.as_list(), out_tensor.shape.as_list()))
else:
logger.info('%20s: %20s -> %20s' % (name, in_tensor.shape.as_list(), out_tensor.shape.as_list()))
return out_tensor return out_tensor
# argscope requires this property # argscope requires this property
wrapped_func.symbolic_function = None wrapped_func.symbolic_function = None
return wrapped_func return wrapped_func
def enable_argscope_for_module(module): def enable_argscope_for_module(module, log_shape=True):
""" """
Overwrite all functions of a given module to support argscope. Overwrite all functions of a given module to support argscope.
Note that this function monkey-patches the module and therefore could have unexpected consequences. Note that this function monkey-patches the module and therefore could have unexpected consequences.
...@@ -86,4 +101,4 @@ def enable_argscope_for_module(module): ...@@ -86,4 +101,4 @@ def enable_argscope_for_module(module):
""" """
for name, obj in getmembers(module): for name, obj in getmembers(module):
if isfunction(obj): if isfunction(obj):
setattr(module, name, argscope_mapper(obj)) setattr(module, name, argscope_mapper(obj, log_shape=log_shape))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment