Shashank Suhas / seminar-breakout / Commits

Commit 838b1df7, authored Dec 27, 2015 by ppwwyyxx

add shape summary

Parent: 5102a8f3
Showing 5 changed files with 46 additions and 13 deletions (+46, -13):
    example_mnist.py     +4   -4
    models/_common.py    +22  -8
    train.py             +1   -1
    utils/__init__.py    +16  -0
    utils/summary.py     +3   -0
example_mnist.py

@@ -35,15 +35,15 @@ def get_model(inputs):
     keep_prob = tf.get_default_graph().get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
     image, label = inputs
-    image = tf.expand_dims(image, 3)    # add a single channel
+    image = tf.expand_dims(image, 3)
-    conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5, padding='valid')
+    conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
     pool0 = MaxPooling('pool0', conv0, 2)
     conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3)
     pool1 = MaxPooling('pool1', conv1, 2)
     conv2 = Conv2D('conv2', pool1, out_channel=32, kernel_shape=3)
-    fc0 = FullyConnected('fc0', pool1, 1024)
+    fc0 = FullyConnected('fc0', conv2, 1024)
     fc0 = tf.nn.dropout(fc0, keep_prob)
     # fc will have activation summary by default. disable this for the output layer
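For reference on the static shapes that the new logging in this commit reports: in the TensorFlow 0.x API used here, a tensor's compile-time shape is available via get_shape().as_list(), with None for dimensions that are unknown until runtime. A minimal standalone sketch, not part of the repository, with an illustrative MNIST-style placeholder:

    import tensorflow as tf

    # Hypothetical input; batch size unknown at graph-construction time.
    image = tf.placeholder(tf.float32, [None, 28, 28], name='image')
    image = tf.expand_dims(image, 3)      # add a single channel, as in get_model above
    print(image.get_shape().as_list())    # [None, 28, 28, 1]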
models/_common.py

@@ -5,6 +5,7 @@
 import tensorflow as tf
 from utils.summary import *
+from utils import logger

 def layer_register(summary_activation=False):
     """
@@ -17,15 +18,28 @@ def layer_register(summary_activation=False):
             args = args[1:]
             do_summary = kwargs.pop('summary_activation', summary_activation)
+            inputs = args[0]
+            if isinstance(inputs, list):
+                shape_str = ",".join(map(str(x.get_shape().as_list()), inputs))
+            else:
+                shape_str = str(inputs.get_shape().as_list())
+            logger.info("{} input: {}".format(name, shape_str))
             with tf.variable_scope(name) as scope:
-                ret = func(*args, **kwargs)
-                if do_summary:
-                    ndim = ret.get_shape().ndims
-                    assert ndim >= 2, \
-                        "Summary a scalar with histogram? Maybe use scalar instead. FIXME!"
-                    add_activation_summary(ret, scope.name)
-                return ret
+                outputs = func(*args, **kwargs)
+                if isinstance(outputs, list):
+                    shape_str = ",".join(map(str(x.get_shape().as_list()), outputs))
+                    if do_summary:
+                        for x in outputs:
+                            add_activation_summary(x, scope.name)
+                else:
+                    shape_str = str(outputs.get_shape().as_list())
+                    if do_summary:
+                        add_activation_summary(outputs, scope.name)
+                logger.info("{} output: {}".format(name, shape_str))
+                return outputs
         return inner
     return wrapper
@@ -35,7 +49,7 @@ def shape2d(a):
     """
     if type(a) == int:
         return [a, a]
-    if type(a) in [list, tuple]:
+    if isinstance(a, (list, tuple)):
         assert len(a) == 2
         return list(a)
     raise RuntimeError("Illegal shape: {}".format(a))
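As a side note on the decorator change above: layer_register now strips the layer name off the argument list, logs the static shape of the layer's inputs and outputs, and optionally attaches an activation summary. A stripped-down, self-contained sketch of that idea follows; it is not the repository's exact code (it uses the standard logging module instead of utils.logger, drops the summary handling, and builds the shape string with a generator expression):

    import logging
    import tensorflow as tf

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    def _shape_str(t):
        # Render one tensor, or a list of tensors, as a shape string.
        if isinstance(t, list):
            return ",".join(str(x.get_shape().as_list()) for x in t)
        return str(t.get_shape().as_list())

    def layer_register(func):
        # Simplified stand-in for the decorator in models/_common.py.
        def inner(name, *args, **kwargs):
            logger.info("{} input: {}".format(name, _shape_str(args[0])))
            with tf.variable_scope(name):
                outputs = func(*args, **kwargs)
            logger.info("{} output: {}".format(name, _shape_str(outputs)))
            return outputs
        return inner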
train.py

@@ -42,13 +42,13 @@ def start_train(config):
     max_epoch = int(config['max_epoch'])

     # build graph
     G = tf.get_default_graph()
     for v in input_vars:
         G.add_to_collection(INPUT_VARS_KEY, v)
     for v in output_vars:
         G.add_to_collection(OUTPUT_VARS_KEY, v)
+    summary_model()

     global_step_var = G.get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
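For context on the collection bookkeeping in this hunk: tensors added to a named graph collection can be looked up later by the same key, which is how the input/output variables registered here are recovered elsewhere. A tiny illustrative sketch with a made-up key (the repository uses its own INPUT_VARS_KEY / OUTPUT_VARS_KEY constants):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 28, 28], name='input')
    tf.get_default_graph().add_to_collection('my_input_vars', x)

    # Later, e.g. when restoring the graph for inference:
    recovered = tf.get_default_graph().get_collection('my_input_vars')
    assert recovered[0] is x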
utils/__init__.py

@@ -9,6 +9,7 @@ import time
 import sys
 from contextlib import contextmanager
 import logger
+import tensorflow as tf

 def global_import(name):
     p = __import__(name, globals(), locals())

@@ -29,3 +30,18 @@ def timed_operation(msg, log_start=False):
     yield
     logger.info('finished {}, time={:.2f}sec.'.format(
         msg, time.time() - start))
+
+def summary_model():
+    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+    msg = [""]
+    total = 0
+    for v in train_vars:
+        shape = v.get_shape()
+        ele = shape.num_elements()
+        total += ele
+        msg.append("{}: shape={}, dim={}".format(v.name, shape.as_list(), ele))
+    msg.append("Total dim={}".format(total))
+    logger.info("Model Params: {}".format('\n'.join(msg)))
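A hedged usage sketch of the new summary_model helper: once the graph is built, it walks tf.GraphKeys.TRAINABLE_VARIABLES and logs one line per variable plus the total parameter count. With the TF 0.x-era graph API used here and a toy two-variable graph, the same traversal looks roughly like this (print stands in for logger.info):

    import tensorflow as tf

    # Toy trainable variables; a tf.Variable is added to TRAINABLE_VARIABLES by default.
    W = tf.Variable(tf.zeros([784, 10]), name='fc/W')
    b = tf.Variable(tf.zeros([10]), name='fc/b')

    total = 0
    for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        shape = v.get_shape()
        total += shape.num_elements()
        print("{}: shape={}, dim={}".format(v.name, shape.as_list(), shape.num_elements()))
    print("Total dim={}".format(total))   # 784*10 + 10 = 7850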
utils/summary.py

@@ -22,6 +22,9 @@ def add_activation_summary(x, name=None):
     Summary for an activation tensor x.
     If name is None, use x.name
     """
+    ndim = x.get_shape().ndims
+    assert ndim >= 2, \
+        "Summary a scalar with histogram? Maybe use scalar instead. FIXME!"
     if name is None:
         name = x.name
     tf.histogram_summary(name + '/activations', x)
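For context, tf.histogram_summary is the pre-1.0 TensorFlow summary API (later replaced by tf.summary.histogram), and the new assertion simply refuses to build a histogram for a 0-d or 1-d tensor. A minimal sketch of the call pattern, assuming that era's API and an illustrative activation tensor:

    import tensorflow as tf

    conv_out = tf.placeholder(tf.float32, [None, 24, 24, 32], name='conv0_output')
    assert conv_out.get_shape().ndims >= 2
    tf.histogram_summary('conv0/activations', conv_out)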