Commit 940a1636 authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

argscope_for_function had some issues when a layer produces multiple … (#1049)

* argscope_for_function had some issues when a layer produces multiple outputs.

* add more docs
parent c366eebf
......@@ -79,13 +79,22 @@ def enable_argscope_for_function(func, log_shape=True):
myfunc = enable_argscope_for_function(myfunc)
Args:
func: function which should be decorated.
log_shape (bool): print input/output shapes of each function.
func: A function mapping one or multiple tensors to one or multiple
tensors.
log_shape (bool): Specify whether the first input resp. output tensor
shape should be printed once.
Remarks:
If the function `func` returns multiple input or output tensors,
only the first input/output tensor shape is displayed during logging.
Returns:
The decorated function.
"""
assert callable(func), "func should be a callable"
@wraps(func)
def wrapped_func(*args, **kwargs):
actual_args = copy.copy(get_arg_scope()[func.__name__])
......@@ -97,9 +106,14 @@ def enable_argscope_for_function(func, log_shape=True):
name = func.__name__ if 'name' not in kwargs else kwargs['name']
if log_shape:
if ('tower' not in ctx.ns_name.lower()) or ctx.is_main_training_tower:
# we assume the first parameter is the most interesting
if isinstance(out_tensor, tuple):
out_tensor_descr = out_tensor[0]
else:
out_tensor_descr = out_tensor
logger.info('%20s: %20s -> %20s' %
(name, in_tensor.shape.as_list(),
out_tensor.shape.as_list()))
out_tensor_descr.shape.as_list()))
return out_tensor
# argscope requires this property
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment