Commit 0e4ddfd6 authored May 05, 2017 by Yuxin Wu

support VariableHolder.all()

parent e5837873
Showing 4 changed files with 34 additions and 9 deletions (+34 / -9):

tensorpack/models/common.py    +30  -1
tensorpack/models/conv2d.py     +2  -4
tensorpack/models/fc.py         +1  -2
tensorpack/models/nonlin.py     +1  -2
tensorpack/models/common.py

@@ -21,7 +21,36 @@ __all__ = ['layer_register', 'disable_layer_logging', 'get_registered_layer', 'V
 class VariableHolder(object):
     """ A proxy to access variables defined in a layer. """
-    pass
+    def __init__(self, **kwargs):
+        """
+        Args:
+            kwargs: {name:variable}
+        """
+        self._vars = {}
+        for k, v in six.iteritems(kwargs):
+            self._add_variable(k, v)
+
+    def _add_variable(self, name, var):
+        print(name, var.name)
+        assert name not in self._vars
+        self._vars[name] = var
+
+    def __setattr__(self, name, var):
+        if not name.startswith('_'):
+            self._add_variable(name, var)
+        else:
+            # private attributes
+            super(VariableHolder, self).__setattr__(name, var)
+
+    def __getattr__(self, name):
+        return self._vars[name]
+
+    def all(self):
+        """
+        Returns:
+            list of all variables
+        """
+        return list(six.itervalues(self._vars))


 def _register(name, func):
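For illustration, a minimal sketch of how the reworked VariableHolder behaves. This is not part of the commit: the variables below are placeholders created only for the example, and the import path is an assumption.

    import tensorflow as tf
    from tensorpack.models.common import VariableHolder  # assumed import path

    W = tf.Variable(tf.zeros([3, 3, 16, 32]), name='W')  # placeholder variables for the example
    b = tf.Variable(tf.zeros([32]), name='b')

    vh = VariableHolder(W=W)   # keyword arguments are stored via _add_variable
    vh.b = b                   # public attribute assignment is stored the same way
    assert vh.W is W           # attribute access proxies to the stored variable
    print(vh.all())            # new in this commit: [W, b] (dict order, not guaranteed)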
tensorpack/models/conv2d.py

@@ -72,8 +72,7 @@ def Conv2D(x, out_channel, kernel_shape,
         conv = tf.concat(outputs, channel_axis)
     ret = nl(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
-    ret.variables = VariableHolder()
-    ret.variables.W = W
+    ret.variables = VariableHolder(W=W)
     if use_bias:
         ret.variables.b = b
     return ret

@@ -166,8 +165,7 @@ def Deconv2D(x, out_shape, kernel_shape,
     conv.set_shape(tf.TensorShape([None] + shp3_static))
     ret = nl(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
-    ret.variables = VariableHolder()
-    ret.variables.W = W
+    ret.variables = VariableHolder(W=W)
     if use_bias:
         ret.variables.b = b
     return ret
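The one-line constructor form relies on __init__ forwarding each keyword argument to _add_variable, so it is equivalent to the two lines it replaces (sketch, using the class from the common.py diff above):

    ret.variables = VariableHolder(W=W)
    # equivalent to the previous version:
    #   ret.variables = VariableHolder()
    #   ret.variables.W = W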
tensorpack/models/fc.py

@@ -48,8 +48,7 @@ def FullyConnected(x, out_dim,
     prod = tf.nn.xw_plus_b(x, W, b) if use_bias else tf.matmul(x, W)
     ret = nl(prod, name='output')
-    ret.variables = VariableHolder()
-    ret.variables.W = W
+    ret.variables = VariableHolder(W=W)
     if use_bias:
         ret.variables.b = b
tensorpack/models/nonlin.py

@@ -56,8 +56,7 @@ def PReLU(x, init=0.001, name='output'):
     x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
     ret = tf.multiply(x, 0.5, name=name)
-    ret.variables = VariableHolder()
-    ret.variables.alpha = alpha
+    ret.variables = VariableHolder(alpha=alpha)
     return ret
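With these changes, the outputs of the layers touched here (Conv2D, Deconv2D, FullyConnected, PReLU) each carry a VariableHolder, so their variables can be collected in one call. A hedged sketch of the intended use; the layer invocation follows tensorpack's layer_register convention and is illustrative, not part of this diff:

    # inside a tensorpack model graph
    logits = FullyConnected('fc0', x, 10)   # hypothetical layer call
    fc_W = logits.variables.W               # attribute-style access to the kernel
    fc_all = logits.variables.all()         # new: e.g. [W, b] when use_bias is True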