Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ef0ecb80
Commit
ef0ecb80
authored
Sep 01, 2020
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update cifar-resnet
parent
52bbe706
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
5 deletions
+14
-5
examples/ResNet/cifar10-resnet.py
examples/ResNet/cifar10-resnet.py
+14
-5
No files found.
examples/ResNet/cifar10-resnet.py
View file @
ef0ecb80
...
@@ -70,6 +70,7 @@ class Model(ModelDesc):
...
@@ -70,6 +70,7 @@ class Model(ModelDesc):
return
l
return
l
with
argscope
([
Conv2D
,
AvgPooling
,
BatchNorm
,
GlobalAvgPooling
],
data_format
=
'channels_first'
),
\
with
argscope
([
Conv2D
,
AvgPooling
,
BatchNorm
,
GlobalAvgPooling
],
data_format
=
'channels_first'
),
\
argscope
(
BatchNorm
,
virtual_batch_size
=
32
),
\
argscope
(
Conv2D
,
use_bias
=
False
,
kernel_size
=
3
,
argscope
(
Conv2D
,
use_bias
=
False
,
kernel_size
=
3
,
kernel_initializer
=
tf
.
variance_scaling_initializer
(
scale
=
2.0
,
mode
=
'fan_out'
)):
kernel_initializer
=
tf
.
variance_scaling_initializer
(
scale
=
2.0
,
mode
=
'fan_out'
)):
l
=
Conv2D
(
'conv0'
,
image
,
16
,
activation
=
BNReLU
)
l
=
Conv2D
(
'conv0'
,
image
,
16
,
activation
=
BNReLU
)
...
@@ -140,17 +141,21 @@ def get_data(train_or_test):
...
@@ -140,17 +141,21 @@ def get_data(train_or_test):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
parser
.
add_argument
(
'-n'
,
'--num
_
units'
,
parser
.
add_argument
(
'-n'
,
'--num
-
units'
,
help
=
'number of units in each stage'
,
help
=
'number of units in each stage'
,
type
=
int
,
default
=
18
)
type
=
int
,
default
=
5
)
parser
.
add_argument
(
'--load'
,
help
=
'load model for training'
)
parser
.
add_argument
(
'--load'
,
help
=
'load model for training'
)
parser
.
add_argument
(
'--logdir'
,
help
=
'log directory'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
NUM_UNITS
=
args
.
num_units
NUM_UNITS
=
args
.
num_units
if
args
.
gpu
:
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
logger
.
auto_set_dir
()
if
args
.
logdir
:
logger
.
set_logger_dir
(
args
.
logdir
)
else
:
logger
.
auto_set_dir
()
dataset_train
=
get_data
(
'train'
)
dataset_train
=
get_data
(
'train'
)
dataset_test
=
get_data
(
'test'
)
dataset_test
=
get_data
(
'test'
)
...
@@ -163,9 +168,13 @@ if __name__ == '__main__':
...
@@ -163,9 +168,13 @@ if __name__ == '__main__':
InferenceRunner
(
dataset_test
,
InferenceRunner
(
dataset_test
,
[
ScalarStats
(
'cost'
),
ClassificationError
(
'wrong_vector'
)]),
[
ScalarStats
(
'cost'
),
ClassificationError
(
'wrong_vector'
)]),
ScheduledHyperParamSetter
(
'learning_rate'
,
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
1
,
0.1
),
(
82
,
0.01
),
(
123
,
0.001
),
(
300
,
0.0002
)])
[(
1
,
0.1
),
(
32
,
0.01
),
(
48
,
0.001
)])
],
],
max_epoch
=
400
,
# models are trained with a mini-batch size of 128 on two GPUs. We
# start with a learningrate of 0.1, divide it by 10 at 32k and 48k iterations,
# andterminate training at 64k iterations
steps_per_epoch
=
1000
,
max_epoch
=
64
,
session_init
=
SmartInit
(
args
.
load
),
session_init
=
SmartInit
(
args
.
load
),
)
)
num_gpu
=
max
(
get_num_gpu
(),
1
)
num_gpu
=
max
(
get_num_gpu
(),
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment