Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
87fad54b
Commit
87fad54b
authored
Mar 04, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix bug in resnet; improve logs for #1100
parent
ed1030b7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
7 deletions
+10
-7
examples/ResNet/imagenet-resnet.py
examples/ResNet/imagenet-resnet.py
+1
-1
tensorpack/tfutils/model_utils.py
tensorpack/tfutils/model_utils.py
+9
-6
No files found.
examples/ResNet/imagenet-resnet.py
View file @
87fad54b
...
@@ -135,7 +135,7 @@ if __name__ == '__main__':
...
@@ -135,7 +135,7 @@ if __name__ == '__main__':
model
=
Model
(
args
.
depth
,
args
.
mode
)
model
=
Model
(
args
.
depth
,
args
.
mode
)
model
.
data_format
=
args
.
data_format
model
.
data_format
=
args
.
data_format
if
model
.
weight_decay_norm
:
if
args
.
weight_decay_norm
:
model
.
weight_decay_pattern
=
".*/W|.*/gamma|.*/beta"
model
.
weight_decay_pattern
=
".*/W|.*/gamma|.*/beta"
if
args
.
eval
:
if
args
.
eval
:
...
...
tensorpack/tfutils/model_utils.py
View file @
87fad54b
...
@@ -40,10 +40,11 @@ def describe_trainable_vars():
...
@@ -40,10 +40,11 @@ def describe_trainable_vars():
total
+=
ele
total
+=
ele
total_bytes
+=
ele
*
v
.
dtype
.
size
total_bytes
+=
ele
*
v
.
dtype
.
size
data
.
append
([
v
.
name
,
shape
,
ele
,
v
.
device
,
v
.
dtype
.
base_dtype
.
name
])
data
.
append
([
v
.
name
,
shape
,
ele
,
v
.
device
,
v
.
dtype
.
base_dtype
.
name
])
headers
=
[
'name'
,
'shape'
,
'
dim
'
,
'device'
,
'dtype'
]
headers
=
[
'name'
,
'shape'
,
'
#elements
'
,
'device'
,
'dtype'
]
dtypes
=
set
([
x
[
4
]
for
x
in
data
])
dtypes
=
list
(
set
([
x
[
4
]
for
x
in
data
]))
if
len
(
dtypes
)
==
1
:
if
len
(
dtypes
)
==
1
and
dtypes
[
0
]
==
"float32"
:
# don't log the dtype if all vars are float32 (default dtype)
for
x
in
data
:
for
x
in
data
:
del
x
[
4
]
del
x
[
4
]
del
headers
[
4
]
del
headers
[
4
]
...
@@ -59,9 +60,11 @@ def describe_trainable_vars():
...
@@ -59,9 +60,11 @@ def describe_trainable_vars():
size_mb
=
total_bytes
/
1024.0
**
2
size_mb
=
total_bytes
/
1024.0
**
2
summary_msg
=
colored
(
summary_msg
=
colored
(
"
\n
Total #vars={}, #params={}, size={:.02f}MB"
.
format
(
"
\n
Number of trainable variables: {}"
.
format
(
len
(
data
))
+
len
(
data
),
total
,
size_mb
),
'cyan'
)
"
\n
Number of parameters (elements): {}"
.
format
(
total
)
+
logger
.
info
(
colored
(
"Trainable Variables:
\n
"
,
'cyan'
)
+
table
+
summary_msg
)
"
\n
Storage space needed for all trainable variables: {:.02f}MB"
.
format
(
size_mb
),
'cyan'
)
logger
.
info
(
colored
(
"List of Trainable Variables:
\n
"
,
'cyan'
)
+
table
+
summary_msg
)
def
get_shape_str
(
tensors
):
def
get_shape_str
(
tensors
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment