Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
9c6e39c5
Commit
9c6e39c5
authored
Jul 17, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
more beautiful shape logging
parent
92748c90
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
44 additions
and
10 deletions
+44
-10
tensorpack/models/registry.py
tensorpack/models/registry.py
+37
-4
tensorpack/tfutils/argscope.py
tensorpack/tfutils/argscope.py
+5
-3
tensorpack/tfutils/model_utils.py
tensorpack/tfutils/model_utils.py
+2
-3
No files found.
tensorpack/models/registry.py
View file @
9c6e39c5
...
...
@@ -4,6 +4,7 @@
import
copy
import
re
import
collections
from
functools
import
wraps
import
six
import
tensorflow
as
tf
...
...
@@ -61,6 +62,40 @@ def disable_layer_logging():
globals
()[
'_LAYER_LOGGED'
]
=
ContainEverything
()
class LayerShapeLogger():
    """
    A class that logs shapes of inputs/outputs of layers,
    during the possibly-nested calls to them.
    """

    def __init__(self):
        # Holds the pending (name, input-shape message) of the innermost
        # layer whose output has not been logged yet.  The invariant kept by
        # push_inputs/push_outputs is that it never holds more than one entry.
        self.stack = collections.deque()
        # Current nesting level; drives the indentation of log lines.
        self.depth = 0

    def _indent(self):
        # Two spaces of indentation per nesting level.
        return (self.depth * 2) * " "

    def push_inputs(self, name, message):
        # A pending entry means the previous layer turned out to contain this
        # nested call: flush its input line on its own and go one level deeper.
        while len(self.stack):
            pending_name, pending_msg = self.stack.pop()
            logger.info(self._indent() + "'{}' input: {}".format(pending_name, pending_msg))
            self.depth += 1
        # Defer logging this layer's input until we know whether it nests.
        self.stack.append((name, message))

    def push_outputs(self, name, message):
        if len(self.stack):
            # No nested layer was called in between: emit the compact
            # single-line "input --> output" form.
            assert len(self.stack) == 1, self.stack
            assert self.stack[-1][0] == name, self.stack
            _, input_msg = self.stack.pop()
            logger.info(self._indent() + "'{}': {} --> {}".format(name, input_msg, message))
        else:
            # Closing a layer that had nested calls: dedent, then log its output.
            self.depth -= 1
            logger.info(self._indent() + "'{}' output: {}".format(name, message))
# Module-level singleton: all registered layers share one logger instance so
# that nesting depth is tracked across the whole model-building call tree.
_SHAPE_LOGGER = LayerShapeLogger()
def
layer_register
(
log_shape
=
False
,
use_scope
=
True
):
...
...
@@ -132,15 +167,13 @@ def layer_register(
scope_name
=
re
.
sub
(
'tower[0-9]+/'
,
''
,
scope
.
name
)
do_log_shape
=
log_shape
and
scope_name
not
in
_LAYER_LOGGED
if
do_log_shape
:
logger
.
info
(
"{} input: {}"
.
format
(
scope
.
name
,
get_shape_str
(
inputs
)
))
_SHAPE_LOGGER
.
push_inputs
(
scope
.
name
,
get_shape_str
(
inputs
))
# run the actual function
outputs
=
func
(
*
args
,
**
actual_args
)
if
do_log_shape
:
# log shape info and add activation
logger
.
info
(
"{} output: {}"
.
format
(
scope
.
name
,
get_shape_str
(
outputs
)))
_SHAPE_LOGGER
.
push_outputs
(
scope
.
name
,
get_shape_str
(
outputs
))
_LAYER_LOGGED
.
add
(
scope_name
)
else
:
# run the actual function
...
...
tensorpack/tfutils/argscope.py
View file @
9c6e39c5
...
...
@@ -10,6 +10,7 @@ import tensorflow as tf
from
..compat
import
is_tfv2
from
..utils
import
logger
from
.model_utils
import
get_shape_str
from
.tower
import
get_current_tower_context
__all__
=
[
'argscope'
,
'get_arg_scope'
,
'enable_argscope_for_module'
,
...
...
@@ -108,9 +109,10 @@ def enable_argscope_for_function(func, log_shape=True):
out_tensor_descr
=
out_tensor
[
0
]
else
:
out_tensor_descr
=
out_tensor
logger
.
info
(
'
%20
s:
%20
s ->
%20
s'
%
(
name
,
in_tensor
.
shape
.
as_list
(),
out_tensor_descr
.
shape
.
as_list
()))
logger
.
info
(
"{:<12}: {} --> {}"
.
format
(
"'"
+
name
+
"'"
,
get_shape_str
(
in_tensor
),
get_shape_str
(
out_tensor_descr
)))
return
out_tensor
wrapped_func
.
__argscope_enabled__
=
True
...
...
tensorpack/tfutils/model_utils.py
View file @
9c6e39c5
...
...
def get_shape_str(tensors):
    """
    Describe the static shape(s) of a tensor or a list/tuple of tensors.

    Args:
        tensors: a ``tf.Tensor``/``tf.Variable``, or a list/tuple whose
            elements are all ``tf.Tensor``/``tf.Variable``.
    Returns:
        str: the shape(s) as a string, with unknown dimensions shown as ``?``.
    """
    if not isinstance(tensors, (list, tuple)):
        # Single tensor: render its static shape, printing None dims as '?'.
        assert isinstance(tensors, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(tensors))
        return str(tensors.get_shape().as_list()).replace("None", "?")
    # Collection of tensors: validate every element, then join the
    # per-tensor descriptions.
    for each in tensors:
        assert isinstance(each, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(each))
    return ", ".join(get_shape_str(each) for each in tensors)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment