Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
c6c9a4ba
Commit
c6c9a4ba
authored
Nov 10, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add epoch_num stat before each epoch
parent
5e2c7309
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
9 additions
and
13 deletions
+9
-13
README.md
README.md
+1
-1
examples/ResNet/README.md
examples/ResNet/README.md
+3
-2
examples/ResNet/imagenet-resnet.py
examples/ResNet/imagenet-resnet.py
+0
-1
tensorpack/callbacks/stat.py
tensorpack/callbacks/stat.py
+4
-1
tensorpack/predict/common.py
tensorpack/predict/common.py
+1
-8
No files found.
README.md
View file @
c6c9a4ba
...
...
@@ -8,7 +8,7 @@ You can actually train them and reproduce the performance... not just to see how
+
[
DoReFa-Net: training binary / low bitwidth CNN
](
examples/DoReFa-Net
)
+
[
InceptionV3 on ImageNet
](
examples/Inception/inceptionv3.py
)
+
[
ResNet for Cifar10 classification
](
examples/ResNet
)
+
[
ResNet for
ImageNet/
Cifar10 classification
](
examples/ResNet
)
+
[
Fully-convolutional Network for Holistically-Nested Edge Detection
](
examples/HED
)
+
[
Spatial Transformer Networks on MNIST addition
](
examples/SpatialTransformer
)
+
[
Double DQN plays Atari games
](
examples/Atari2600
)
...
...
examples/ResNet/README.md
View file @
c6c9a4ba
...
...
@@ -2,9 +2,10 @@
## imagenet-resnet.py
ImageNet training code of pre-activation ResNet. It follows the setup in
[
fb.resnet.torch
](
https://github.com/facebook/fb.resnet.torch
)
and get similar performance (with much fewer lines of code),
[
fb.resnet.torch
](
https://github.com/facebook/fb.resnet.torch
)
and gets similar performance (with much fewer lines of code).
More results to come.
| Model
(WIP)
| Top 5 Error | Top 1 Error |
| Model
| Top 5 Error | Top 1 Error |
|:-------------------|-------------|------------:|
| ResNet 18 | 10.67% | 29.50% |
| ResNet 50 | 7.13% | 24.12% |
...
...
examples/ResNet/imagenet-resnet.py
View file @
c6c9a4ba
...
...
@@ -116,7 +116,6 @@ class Model(ModelDesc):
wrong
=
prediction_incorrect
(
logits
,
label
,
5
,
name
=
'wrong-top5'
)
add_moving_summary
(
tf
.
reduce_mean
(
wrong
,
name
=
'train-error-top5'
))
# weight decay on all W of fc layers
wd_w
=
1e-4
wd_cost
=
tf
.
mul
(
wd_w
,
regularize_cost
(
'.*/W'
,
tf
.
nn
.
l2_loss
),
name
=
'l2_regularize_loss'
)
add_moving_summary
(
loss
,
wd_cost
)
...
...
tensorpack/callbacks/stat.py
View file @
c6c9a4ba
...
...
@@ -107,11 +107,14 @@ class StatPrinter(Callback):
self
.
_stat_holder
.
set_print_tag
(
self
.
print_tag
)
self
.
_stat_holder
.
add_blacklist_tag
([
'global_step'
,
'epoch_num'
])
# just try to add this stat earlier so SendStat can use
self
.
_stat_holder
.
add_stat
(
'epoch_num'
,
self
.
epoch_num
+
1
)
def
_trigger_epoch
(
self
):
# by default, add this two stat
self
.
_stat_holder
.
add_stat
(
'global_step'
,
get_global_step
())
self
.
_stat_holder
.
add_stat
(
'epoch_num'
,
self
.
epoch_num
)
self
.
_stat_holder
.
finalize
()
self
.
_stat_holder
.
add_stat
(
'epoch_num'
,
self
.
epoch_num
+
1
)
class
SendStat
(
Callback
):
"""
...
...
tensorpack/predict/common.py
View file @
c6c9a4ba
...
...
@@ -27,7 +27,6 @@ class PredictConfig(object):
:param session_init: a `utils.sessinit.SessionInit` instance to
initialize variables of a session.
:param input_var_names: a list of input variable names.
:param input_data_mapping: deprecated. used to select `input_var_names` from the `InputVars` of the model.
:param model: a `ModelDesc` instance
:param output_var_names: a list of names of the output tensors to predict, the
variables can be any computable tensor in the graph.
...
...
@@ -47,13 +46,7 @@ class PredictConfig(object):
# inputs & outputs
self
.
input_var_names
=
kwargs
.
pop
(
'input_var_names'
,
None
)
input_mapping
=
kwargs
.
pop
(
'input_data_mapping'
,
None
)
if
input_mapping
:
raw_vars
=
self
.
model
.
get_input_vars_desc
()
self
.
input_var_names
=
[
raw_vars
[
k
]
.
name
for
k
in
input_mapping
]
logger
.
warn
(
'The option `input_data_mapping` was deprecated.
\
Use
\'
input_var_names=[{}]
\'
instead'
.
format
(
', '
.
join
(
self
.
input_var_names
)))
elif
self
.
input_var_names
is
None
:
if
self
.
input_var_names
is
None
:
# neither options is set, assume all inputs
raw_vars
=
self
.
model
.
get_input_vars_desc
()
self
.
input_var_names
=
[
k
.
name
for
k
in
raw_vars
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment