Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e1cfbef8
Commit
e1cfbef8
authored
Jul 28, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
clean-up import in ResNet
parent
08a5cf6f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
14 deletions
+12
-14
examples/ResNet/imagenet-resnet.py
examples/ResNet/imagenet-resnet.py
+12
-14
No files found.
examples/ResNet/imagenet-resnet.py
View file @
e1cfbef8
...
...
@@ -2,7 +2,6 @@
# -*- coding: UTF-8 -*-
# File: imagenet-resnet.py
import
cv2
import
sys
import
argparse
import
numpy
as
np
...
...
@@ -10,9 +9,8 @@ import os
import
multiprocessing
import
tensorflow
as
tf
from
tensorflow.contrib.layers
import
variance_scaling_initializer
from
tensorpack
import
*
from
tensorpack.utils.stats
import
RatioCounter
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack.tfutils.summary
import
*
...
...
@@ -24,6 +22,13 @@ TOTAL_BATCH_SIZE = 256
INPUT_SHAPE
=
224
DEPTH
=
None
RESNET_CONFIG
=
{
18
:
([
2
,
2
,
2
,
2
],
resnet_basicblock
),
34
:
([
3
,
4
,
6
,
3
],
resnet_basicblock
),
50
:
([
3
,
4
,
6
,
3
],
resnet_bottleneck
),
101
:
([
3
,
4
,
23
,
3
],
resnet_bottleneck
)
}
class
Model
(
ModelDesc
):
def
__init__
(
self
,
data_format
=
'NCHW'
):
...
...
@@ -46,23 +51,16 @@ class Model(ModelDesc):
if
self
.
data_format
==
'NCHW'
:
image
=
tf
.
transpose
(
image
,
[
0
,
3
,
1
,
2
])
cfg
=
{
18
:
([
2
,
2
,
2
,
2
],
resnet_basicblock
),
34
:
([
3
,
4
,
6
,
3
],
resnet_basicblock
),
50
:
([
3
,
4
,
6
,
3
],
resnet_bottleneck
),
101
:
([
3
,
4
,
23
,
3
],
resnet_bottleneck
)
}
defs
,
block_func
=
cfg
[
DEPTH
]
defs
,
block_func
=
RESNET_CONFIG
[
DEPTH
]
with
argscope
([
Conv2D
,
MaxPooling
,
GlobalAvgPooling
,
BatchNorm
],
data_format
=
self
.
data_format
):
logits
=
resnet_backbone
(
image
,
defs
,
block_func
)
loss
=
compute_loss_and_error
(
logits
,
label
)
wd_
cost
=
regularize_cost
(
'.*/W'
,
l2_regularizer
(
1e-4
),
name
=
'l2_regularize_loss'
)
add_moving_summary
(
loss
,
wd_
cost
)
self
.
cost
=
tf
.
add_n
([
loss
,
wd_
cost
],
name
=
'cost'
)
wd_
loss
=
regularize_cost
(
'.*/W'
,
l2_regularizer
(
1e-4
),
name
=
'l2_regularize_loss'
)
add_moving_summary
(
loss
,
wd_
loss
)
self
.
cost
=
tf
.
add_n
([
loss
,
wd_
loss
],
name
=
'cost'
)
def
_get_optimizer
(
self
):
lr
=
get_scalar_var
(
'learning_rate'
,
0.1
,
summary
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment