Commit bee6798b authored by Yuxin Wu's avatar Yuxin Wu

fix some layers

parent 6c4f2351
...@@ -106,14 +106,14 @@ def layer_register( ...@@ -106,14 +106,14 @@ def layer_register(
name, inputs = args[0], args[1] name, inputs = args[0], args[1]
args = args[1:] # actual positional args used to call func args = args[1:] # actual positional args used to call func
assert isinstance(name, six.string_types), name assert isinstance(name, six.string_types), name
elif use_scope is False: else:
assert not log_shape
inputs = args[0]
name = None
assert not isinstance(args[0], six.string_types), name
else: # use_scope is None
assert not log_shape assert not log_shape
if isinstance(args[0], six.string_types): if isinstance(args[0], six.string_types):
if use_scope is False:
logger.warn(
"Please call layer {} without the first scope name argument, "
"or register the layer with use_scope=None to allow "
"two calling methods.".format(func.__name__))
name, inputs = args[0], args[1] name, inputs = args[0], args[1]
args = args[1:] # actual positional args used to call func args = args[1:] # actual positional args used to call func
else: else:
......
...@@ -45,7 +45,7 @@ def sample(img, coords): ...@@ -45,7 +45,7 @@ def sample(img, coords):
return sampled return sampled
@layer_register() @layer_register(log_shape=True)
def ImageSample(inputs, borderMode='repeat'): def ImageSample(inputs, borderMode='repeat'):
""" """
Sample the template image using the given coordinate, by bilinear interpolation. Sample the template image using the given coordinate, by bilinear interpolation.
......
...@@ -80,7 +80,7 @@ def regularize_cost_from_collection(name='regularize_cost'): ...@@ -80,7 +80,7 @@ def regularize_cost_from_collection(name='regularize_cost'):
return None return None
@layer_register(log_shape=False, use_scope=False) @layer_register(use_scope=None)
def Dropout(x, keep_prob=0.5, is_training=None, noise_shape=None): def Dropout(x, keep_prob=0.5, is_training=None, noise_shape=None):
""" """
Dropout layer as in the paper `Dropout: a Simple Way to Prevent Dropout layer as in the paper `Dropout: a Simple Way to Prevent
......
...@@ -9,7 +9,7 @@ from .common import layer_register ...@@ -9,7 +9,7 @@ from .common import layer_register
__all__ = ['ConcatWith'] __all__ = ['ConcatWith']
@layer_register(use_scope=False, log_shape=False) @layer_register(use_scope=None)
def ConcatWith(x, tensor, dim): def ConcatWith(x, tensor, dim):
""" """
A wrapper around ``tf.concat`` to cooperate with :class:`LinearWrap`. A wrapper around ``tf.concat`` to cooperate with :class:`LinearWrap`.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment