Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
931640c5
Commit
931640c5
authored
Jan 07, 2019
by
Patrick Wieschollek
Committed by
Yuxin Wu
Jan 07, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
enable_arg_scope for a given function (#1035)
parent
4eaeed3f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
6 deletions
+28
-6
tensorpack/tfutils/argscope.py
tensorpack/tfutils/argscope.py
+28
-6
No files found.
tensorpack/tfutils/argscope.py
View file @
931640c5
...
@@ -10,7 +10,8 @@ from inspect import getmembers, isfunction
...
@@ -10,7 +10,8 @@ from inspect import getmembers, isfunction
from
..utils
import
logger
from
..utils
import
logger
from
.tower
import
get_current_tower_context
from
.tower
import
get_current_tower_context
__all__
=
[
'argscope'
,
'get_arg_scope'
,
'enable_argscope_for_module'
]
__all__
=
[
'argscope'
,
'get_arg_scope'
,
'enable_argscope_for_module'
,
'enable_argscope_for_function'
]
_ArgScopeStack
=
[]
_ArgScopeStack
=
[]
...
@@ -67,8 +68,21 @@ def get_arg_scope():
...
@@ -67,8 +68,21 @@ def get_arg_scope():
return
defaultdict
(
dict
)
return
defaultdict
(
dict
)
def
argscope_mapper
(
func
,
log_shape
=
True
):
def
enable_argscope_for_function
(
func
,
log_shape
=
True
):
"""Decorator for function to support argscope
"""Decorator for function to support argscope
Example:
.. code-block:: python
from mylib import myfunc
myfunc = enable_argscope_for_function(myfunc)
Args:
func: function which should be decorated.
log_shape (bool): print input/output shapes of each function.
Returns:
The decorated function.
"""
"""
@
wraps
(
func
)
@
wraps
(
func
)
def
wrapped_func
(
*
args
,
**
kwargs
):
def
wrapped_func
(
*
args
,
**
kwargs
):
...
@@ -82,7 +96,8 @@ def argscope_mapper(func, log_shape=True):
...
@@ -82,7 +96,8 @@ def argscope_mapper(func, log_shape=True):
if
log_shape
:
if
log_shape
:
if
(
'tower'
not
in
ctx
.
ns_name
.
lower
())
or
ctx
.
is_main_training_tower
:
if
(
'tower'
not
in
ctx
.
ns_name
.
lower
())
or
ctx
.
is_main_training_tower
:
logger
.
info
(
'
%20
s:
%20
s ->
%20
s'
%
logger
.
info
(
'
%20
s:
%20
s ->
%20
s'
%
(
name
,
in_tensor
.
shape
.
as_list
(),
out_tensor
.
shape
.
as_list
()))
(
name
,
in_tensor
.
shape
.
as_list
(),
out_tensor
.
shape
.
as_list
()))
return
out_tensor
return
out_tensor
# argscope requires this property
# argscope requires this property
...
@@ -93,12 +108,19 @@ def argscope_mapper(func, log_shape=True):
...
@@ -93,12 +108,19 @@ def argscope_mapper(func, log_shape=True):
def
enable_argscope_for_module
(
module
,
log_shape
=
True
):
def
enable_argscope_for_module
(
module
,
log_shape
=
True
):
"""
"""
Overwrite all functions of a given module to support argscope.
Overwrite all functions of a given module to support argscope.
Note that this function monkey-patches the module and therefore could have unexpected consequences.
Note that this function monkey-patches the module and therefore could
have unexpected consequences.
It has been only tested to work well with `tf.layers` module.
It has been only tested to work well with `tf.layers` module.
Example:
.. code-block:: python
import tensorflow as tf
enable_argscope_for_module(tf.layers)
Args:
Args:
log_shape (bool): print input/output shapes of each function
when called
.
log_shape (bool): print input/output shapes of each function.
"""
"""
for
name
,
obj
in
getmembers
(
module
):
for
name
,
obj
in
getmembers
(
module
):
if
isfunction
(
obj
):
if
isfunction
(
obj
):
setattr
(
module
,
name
,
argscope_mapper
(
obj
,
log_shape
=
log_shape
))
setattr
(
module
,
name
,
enable_argscope_for_function
(
obj
,
log_shape
=
log_shape
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment