Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
90dd3ef4
Commit
90dd3ef4
authored
Mar 27, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update resnet with better init
parent
2d720b60
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
8 additions
and
6 deletions
+8
-6
examples/cifar10_resnet.py
examples/cifar10_resnet.py
+4
-4
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+3
-0
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+0
-1
tensorpack/train/base.py
tensorpack/train/base.py
+1
-1
No files found.
examples/cifar10_resnet.py
View file @
90dd3ef4
...
...
@@ -46,7 +46,7 @@ class Model(ModelDesc):
def
conv
(
name
,
l
,
channel
,
stride
):
return
Conv2D
(
name
,
l
,
channel
,
3
,
stride
=
stride
,
nl
=
tf
.
identity
,
use_bias
=
False
,
W_init
=
tf
.
contrib
.
layers
.
xavier_initializer_conv2d
(
False
))
W_init
=
tf
.
random_normal_initializer
(
stddev
=
2.0
/
9
/
channel
))
def
residual
(
name
,
l
,
increase_dim
=
False
):
shape
=
l
.
get_shape
()
.
as_list
()
...
...
@@ -124,7 +124,6 @@ def get_data(train_or_test):
imgaug
.
RandomCrop
((
32
,
32
)),
imgaug
.
Flip
(
horiz
=
True
),
imgaug
.
BrightnessAdd
(
20
),
imgaug
.
Contrast
((
0.6
,
1.4
)),
imgaug
.
MapImage
(
lambda
x
:
x
-
pp_mean
),
]
else
:
...
...
@@ -147,7 +146,8 @@ def get_config():
sess_config
=
get_default_sess_config
(
0.9
)
lr
=
tf
.
Variable
(
0.1
,
trainable
=
False
,
name
=
'learning_rate'
)
# warm up with small LR for 1 epoch
lr
=
tf
.
Variable
(
0.01
,
trainable
=
False
,
name
=
'learning_rate'
)
tf
.
scalar_summary
(
'learning_rate'
,
lr
)
return
TrainConfig
(
...
...
@@ -158,7 +158,7 @@ def get_config():
PeriodicSaver
(),
ValidationError
(
dataset_test
,
prefix
=
'test'
),
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
82
,
0.01
),
(
123
,
0.001
),
(
300
,
0.0001
)])
[(
1
,
0.1
),
(
82
,
0.01
),
(
123
,
0.001
),
(
300
,
0.0001
)])
]),
session_config
=
sess_config
,
model
=
Model
(
n
=
18
),
...
...
tensorpack/callbacks/base.py
View file @
90dd3ef4
...
...
@@ -56,6 +56,9 @@ class Callback(object):
return
self
.
trainer
.
global_step
def
trigger_epoch
(
self
):
"""
epoch_num is the number of epoch finished.
"""
self
.
epoch_num
+=
1
self
.
_trigger_epoch
()
...
...
tensorpack/callbacks/param.py
View file @
90dd3ef4
...
...
@@ -25,7 +25,6 @@ class HyperParamSetter(Callback):
def
_before_train
(
self
):
all_vars
=
tf
.
all_variables
()
for
v
in
all_vars
:
print
v
.
name
if
v
.
name
==
self
.
var_name
:
self
.
var
=
v
break
...
...
tensorpack/train/base.py
View file @
90dd3ef4
...
...
@@ -49,7 +49,7 @@ class Trainer(object):
if
not
hasattr
(
logger
,
'LOG_DIR'
):
raise
RuntimeError
(
"Please use logger.set_logger_dir at the beginning of your script."
)
self
.
summary_writer
=
tf
.
train
.
SummaryWriter
(
logger
.
LOG_DIR
,
graph_def
=
self
.
sess
.
graph
_def
)
logger
.
LOG_DIR
,
graph_def
=
self
.
sess
.
graph
)
self
.
summary_op
=
tf
.
merge_all_summaries
()
# create an empty StatHolder
self
.
stat_holder
=
StatHolder
(
logger
.
LOG_DIR
,
[])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment