Commit 0e4ddfd6 authored by Yuxin Wu

support VariableHolder.all()

parent e5837873
@@ -21,7 +21,36 @@ __all__ = ['layer_register', 'disable_layer_logging', 'get_registered_layer', 'V
 class VariableHolder(object):
     """ A proxy to access variables defined in a layer. """
-    pass
+    def __init__(self, **kwargs):
+        """
+        Args:
+            kwargs: {name:variable}
+        """
+        self._vars = {}
+        for k, v in six.iteritems(kwargs):
+            self._add_variable(k, v)
+
+    def _add_variable(self, name, var):
+        assert name not in self._vars
+        self._vars[name] = var
+
+    def __setattr__(self, name, var):
+        if not name.startswith('_'):
+            self._add_variable(name, var)
+        else:
+            # private attributes
+            super(VariableHolder, self).__setattr__(name, var)
+
+    def __getattr__(self, name):
+        return self._vars[name]
+
+    def all(self):
+        """
+        Returns:
+            list of all variables
+        """
+        return list(six.itervalues(self._vars))
 def _register(name, func):
...
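For readers skimming the diff, a minimal sketch of the new semantics (not part of the commit; it assumes the VariableHolder class above is in scope and uses TF1-style variables as stand-ins):

import tensorflow as tf

W = tf.Variable(tf.zeros([3, 3]), name='W')
b = tf.Variable(tf.zeros([3]), name='b')

holder = VariableHolder(W=W)   # register through the new **kwargs constructor
holder.b = b                   # __setattr__ routes public names into _vars
assert holder.W is W           # __getattr__ reads them back out of _vars
print(holder.all())            # both variables, via the new all() accessor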
@@ -72,8 +72,7 @@ def Conv2D(x, out_channel, kernel_shape,
     conv = tf.concat(outputs, channel_axis)
     ret = nl(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
-    ret.variables = VariableHolder()
-    ret.variables.W = W
+    ret.variables = VariableHolder(W=W)
     if use_bias:
         ret.variables.b = b
     return ret
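As context, a hypothetical caller-side snippet (not from this commit) showing what the refactor preserves; it assumes tensorpack's registered-layer convention of passing the layer name first, and the default use_bias=True:

import tensorflow as tf
from tensorpack.models import Conv2D

image = tf.placeholder(tf.float32, [None, 32, 32, 3])
out = Conv2D('conv0', image, out_channel=16, kernel_shape=3)

kernel = out.variables.W  # the kernel registered via VariableHolder(W=W)
bias = out.variables.b    # still attached afterwards when use_bias=True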
@@ -166,8 +165,7 @@ def Deconv2D(x, out_shape, kernel_shape,
     conv.set_shape(tf.TensorShape([None] + shp3_static))
     ret = nl(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
-    ret.variables = VariableHolder()
-    ret.variables.W = W
+    ret.variables = VariableHolder(W=W)
     if use_bias:
         ret.variables.b = b
     return ret
@@ -48,8 +48,7 @@ def FullyConnected(x, out_dim,
     prod = tf.nn.xw_plus_b(x, W, b) if use_bias else tf.matmul(x, W)
     ret = nl(prod, name='output')
-    ret.variables = VariableHolder()
-    ret.variables.W = W
+    ret.variables = VariableHolder(W=W)
     if use_bias:
         ret.variables.b = b
...
@@ -56,8 +56,7 @@ def PReLU(x, init=0.001, name='output'):
     x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
     ret = tf.multiply(x, 0.5, name=name)
-    ret.variables = VariableHolder()
-    ret.variables.alpha = alpha
+    ret.variables = VariableHolder(alpha=alpha)
     return ret
...
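Finally, a hedged sketch of what the new all() enables for these layers: bulk access to everything a layer created, e.g. per-layer weight decay. The import path and the 1e-4 coefficient are illustrative assumptions:

import tensorflow as tf
from tensorpack.models import FullyConnected

feats = tf.placeholder(tf.float32, [None, 128])
fc = FullyConnected('fc0', feats, out_dim=10)

# all() returns the layer's W and b here, so regularizing them is a one-liner
reg_cost = 1e-4 * tf.add_n([tf.nn.l2_loss(v) for v in fc.variables.all()])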