Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
78c4cf13
Commit
78c4cf13
authored
May 30, 2018
by
Patrick Wieschollek
Committed by
Yuxin Wu
May 30, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
support tf.layers in argscope (#778)
* support tf.layers in argscope * rename * "lib"->"module"; add docs * typo
parent
099975c5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
12 deletions
+41
-12
examples/basics/mnist-tflayers.py
examples/basics/mnist-tflayers.py
+13
-11
tensorpack/tfutils/argscope.py
tensorpack/tfutils/argscope.py
+28
-1
No files found.
examples/basics/mnist-tflayers.py
View file @
78c4cf13
...
@@ -20,6 +20,7 @@ from tensorpack.tfutils import summary, get_current_tower_context
...
@@ -20,6 +20,7 @@ from tensorpack.tfutils import summary, get_current_tower_context
from
tensorpack.dataflow
import
dataset
from
tensorpack.dataflow
import
dataset
IMAGE_SIZE
=
28
IMAGE_SIZE
=
28
enable_argscope_for_module
(
tf
.
layers
)
class
Model
(
ModelDesc
):
class
Model
(
ModelDesc
):
...
@@ -38,16 +39,17 @@ class Model(ModelDesc):
...
@@ -38,16 +39,17 @@ class Model(ModelDesc):
image
=
image
*
2
-
1
# center the pixels values at zero
image
=
image
*
2
-
1
# center the pixels values at zero
l
=
tf
.
layers
.
conv2d
(
image
,
32
,
3
,
padding
=
'same'
,
activation
=
tf
.
nn
.
relu
,
name
=
'conv0'
)
with
argscope
([
tf
.
layers
.
conv2d
],
padding
=
'same'
,
activation
=
tf
.
nn
.
relu
):
l
=
tf
.
layers
.
max_pooling2d
(
l
,
2
,
2
,
padding
=
'valid'
)
l
=
tf
.
layers
.
conv2d
(
image
,
32
,
3
,
name
=
'conv0'
)
l
=
tf
.
layers
.
conv2d
(
l
,
32
,
3
,
padding
=
'same'
,
activation
=
tf
.
nn
.
relu
,
name
=
'conv1'
)
l
=
tf
.
layers
.
max_pooling2d
(
l
,
2
,
2
,
padding
=
'valid'
)
l
=
tf
.
layers
.
conv2d
(
l
,
32
,
3
,
padding
=
'same'
,
activation
=
tf
.
nn
.
relu
,
name
=
'conv2'
)
l
=
tf
.
layers
.
conv2d
(
l
,
32
,
3
,
name
=
'conv1'
)
l
=
tf
.
layers
.
max_pooling2d
(
l
,
2
,
2
,
padding
=
'valid'
)
l
=
tf
.
layers
.
conv2d
(
l
,
32
,
3
,
name
=
'conv2'
)
l
=
tf
.
layers
.
conv2d
(
l
,
32
,
3
,
padding
=
'same'
,
activation
=
tf
.
nn
.
relu
,
name
=
'conv3'
)
l
=
tf
.
layers
.
max_pooling2d
(
l
,
2
,
2
,
padding
=
'valid'
)
l
=
tf
.
layers
.
flatten
(
l
)
l
=
tf
.
layers
.
conv2d
(
l
,
32
,
3
,
name
=
'conv3'
)
l
=
tf
.
layers
.
dense
(
l
,
512
,
activation
=
tf
.
nn
.
relu
,
name
=
'fc0'
)
l
=
tf
.
layers
.
flatten
(
l
)
l
=
tf
.
layers
.
dropout
(
l
,
rate
=
0.5
,
l
=
tf
.
layers
.
dense
(
l
,
512
,
activation
=
tf
.
nn
.
relu
,
name
=
'fc0'
)
training
=
get_current_tower_context
()
.
is_training
)
l
=
tf
.
layers
.
dropout
(
l
,
rate
=
0.5
,
training
=
get_current_tower_context
()
.
is_training
)
logits
=
tf
.
layers
.
dense
(
l
,
10
,
activation
=
tf
.
identity
,
name
=
'fc1'
)
logits
=
tf
.
layers
.
dense
(
l
,
10
,
activation
=
tf
.
identity
,
name
=
'fc1'
)
tf
.
nn
.
softmax
(
logits
,
name
=
'prob'
)
# a Bx10 with probabilities
tf
.
nn
.
softmax
(
logits
,
name
=
'prob'
)
# a Bx10 with probabilities
...
@@ -60,7 +62,7 @@ class Model(ModelDesc):
...
@@ -60,7 +62,7 @@ class Model(ModelDesc):
accuracy
=
tf
.
reduce_mean
(
correct
,
name
=
'accuracy'
)
accuracy
=
tf
.
reduce_mean
(
correct
,
name
=
'accuracy'
)
# This will monitor training error (in a moving_average fashion):
# This will monitor training error (in a moving_average fashion):
# 1. write the value to tensorboard
# 1. write the value to tensorboard
# 2. write the value to stat.json
# 2. write the value to stat.json
# 3. print the value after each epoch
# 3. print the value after each epoch
train_error
=
tf
.
reduce_mean
(
1
-
correct
,
name
=
'train_error'
)
train_error
=
tf
.
reduce_mean
(
1
-
correct
,
name
=
'train_error'
)
...
...
tensorpack/tfutils/argscope.py
View file @
78c4cf13
...
@@ -4,8 +4,10 @@
...
@@ -4,8 +4,10 @@
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
collections
import
defaultdict
from
collections
import
defaultdict
import
copy
import
copy
from
functools
import
wraps
from
inspect
import
isfunction
,
getmembers
__all__
=
[
'argscope'
,
'get_arg_scope'
]
__all__
=
[
'argscope'
,
'get_arg_scope'
,
'enable_argscope_for_module'
]
_ArgScopeStack
=
[]
_ArgScopeStack
=
[]
...
@@ -60,3 +62,28 @@ def get_arg_scope():
...
@@ -60,3 +62,28 @@ def get_arg_scope():
return
_ArgScopeStack
[
-
1
]
return
_ArgScopeStack
[
-
1
]
else
:
else
:
return
defaultdict
(
dict
)
return
defaultdict
(
dict
)
def argscope_mapper(func):
    """Decorator for function to support argscope

    Wraps ``func`` so that every call first looks up the defaults registered
    for it in the current argscope (keyed by ``func.__name__``) and merges
    them into the call. Keyword arguments passed explicitly by the caller
    take precedence over argscope defaults.

    Args:
        func: the function to wrap (e.g. a ``tf.layers`` function).

    Returns:
        The wrapped function, carrying a ``symbolic_function`` attribute
        (set to ``None``) which the argscope machinery requires.
    """
    @wraps(func)
    def wrapped_func(*args, **kwargs):
        # Copy first so updating with the caller's kwargs does not mutate
        # the dict stored on the argscope stack.
        actual_args = copy.copy(get_arg_scope()[func.__name__])
        actual_args.update(kwargs)
        return func(*args, **actual_args)

    # argscope requires this property
    wrapped_func.symbolic_function = None
    return wrapped_func
def enable_argscope_for_module(module):
    """
    Make every function of a given module respect argscope defaults.

    Note that this monkey-patches the module in place and therefore could
    have unexpected consequences. It has only been tested to work well
    with the `tf.layers` module.
    """
    for member_name, member in getmembers(module):
        # Skip classes, constants and submodules — only plain functions
        # can be wrapped by argscope_mapper.
        if not isfunction(member):
            continue
        setattr(module, member_name, argscope_mapper(member))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment