Shashank Suhas / seminar-breakout · Commits

Commit d5f3350d, authored Aug 09, 2017 by Yuxin Wu

update resnet example to contain #139

parent 74c80d57
Showing 3 changed files with 27 additions and 17 deletions (+27 -17):

examples/ResNet/imagenet-resnet.py        +16 -10
tensorpack/callbacks/inference_runner.py   +1  -0
tensorpack/dataflow/common.py             +10  -7
examples/ResNet/imagenet-resnet.py

@@ -86,25 +86,31 @@ def get_config(fake=False, data_format='NCHW'):
     if fake:
         logger.info("For benchmark, batch size is fixed to 64 per tower.")
-        dataset_train = dataset_val = FakeData(
+        dataset_train = FakeData(
             [[64, 224, 224, 3], [64]], 1000, random=False, dtype='uint8')
+        callbacks = []
     else:
         logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, BATCH_SIZE))
         dataset_train = get_data('train')
         dataset_val = get_data('val')
-
-    return TrainConfig(
-        model=Model(data_format=data_format),
-        dataflow=dataset_train,
-        callbacks=[
+        callbacks = [
             ModelSaver(),
-            InferenceRunner(dataset_val, [
-                ClassificationError('wrong-top1', 'val-error-top1'),
-                ClassificationError('wrong-top5', 'val-error-top5')]),
             ScheduledHyperParamSetter('learning_rate',
                                       [(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5), (105, 1e-6)]),
             HumanHyperParamSetter('learning_rate'),
-        ],
+        ]
+        infs = [ClassificationError('wrong-top1', 'val-error-top1'),
+                ClassificationError('wrong-top5', 'val-error-top5')]
+        if nr_tower == 1:
+            callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
+        else:
+            callbacks.append(DataParallelInferenceRunner(dataset_val, infs, list(range(nr_tower))))
+
+    return TrainConfig(
+        model=Model(data_format=data_format),
+        dataflow=dataset_train,
+        callbacks=callbacks,
         steps_per_epoch=5000,
         max_epoch=110,
         nr_tower=nr_tower
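The substance of this change is the inference-callback selection: with a single GPU the validation DataFlow is wrapped in a QueueInput and evaluated by a plain InferenceRunner, while with several towers it goes to DataParallelInferenceRunner so evaluation is sharded across GPUs. A minimal sketch of the pattern, assuming the tensorpack names used in the diff (the helper function itself is hypothetical):

# Hypothetical helper distilling the callback-selection pattern above;
# assumes the tensorpack API as of this commit.
from tensorpack import (ClassificationError, DataParallelInferenceRunner,
                        InferenceRunner, QueueInput)

def make_inference_callback(dataset_val, nr_tower):
    infs = [ClassificationError('wrong-top1', 'val-error-top1'),
            ClassificationError('wrong-top5', 'val-error-top5')]
    if nr_tower == 1:
        # single tower: feed the validation DataFlow through a queue
        return InferenceRunner(QueueInput(dataset_val), infs)
    # multiple towers: shard the validation set across all GPUs
    return DataParallelInferenceRunner(dataset_val, infs, list(range(nr_tower)))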
tensorpack/callbacks/inference_runner.py

@@ -78,6 +78,7 @@ class InferenceRunnerBase(Callback):
             self._size = input.size()
         except NotImplementedError:
             raise ValueError("Input used in InferenceRunner must have a size!")
+        logger.info("InferenceRunner will eval on an InputSource of size {}".format(self._size))

         if extra_hooks is None:
             extra_hooks = []
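The one-line addition logs the evaluation size after the existing guard, which converts a NotImplementedError from InputSource.size() into a clearer ValueError. A minimal sketch of that guard in isolation (UnsizedInput and check_size are hypothetical names):

# Hypothetical illustration of the size guard in InferenceRunnerBase.
class UnsizedInput:
    def size(self):
        raise NotImplementedError()

def check_size(input):
    try:
        size = input.size()
    except NotImplementedError:
        raise ValueError("Input used in InferenceRunner must have a size!")
    print("InferenceRunner will eval on an InputSource of size {}".format(size))
    return size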
tensorpack/dataflow/common.py

@@ -137,8 +137,11 @@ class BatchData(ProxyDataFlow):
                 raise
             except:
                 logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?")
-                import IPython as IP
-                IP.embed(config=IP.terminal.ipapp.load_default_config())
+                try:
+                    # open an ipython shell if possible
+                    import IPython as IP; IP.embed()    # noqa
+                except:
+                    pass
         return result
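The rewritten debug hook wraps both the IPython import and the embed() call in a broad try/except, so the interactive shell opens only when IPython is importable and any failure falls through silently instead of masking the original batching error. The same guard in isolation (debug_shell is a hypothetical name):

# Hypothetical standalone version of the guarded debug hook above.
def debug_shell():
    try:
        # open an ipython shell if possible
        import IPython as IP; IP.embed()    # noqa
    except Exception:
        pass  # IPython missing or embedding failed; carry on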
@@ -679,14 +682,14 @@ class PrintData(ProxyDataFlow):
         """
         Dump gathered debugging information to stdout.
         """
-        label = "" if self.name is None else " (" + self.label + ")"
-        logger.info(colored("DataFlow Info%s:" % label, 'cyan'))
+        msg = [""]
         for i, dummy in enumerate(itertools.islice(self.ds.get_data(), self.num)):
             if isinstance(dummy, list):
-                msg = "datapoint %i<%i with %i components consists of\n" % (i, self.num, len(dummy))
+                msg.append("datapoint %i<%i with %i components consists of" % (i, self.num, len(dummy)))
                 for k, entry in enumerate(dummy):
-                    msg += self._analyze_input_data(entry, k) + '\n'
-                print(msg)
+                    msg.append(self._analyze_input_data(entry, k))
+        label = "" if self.name is None else " (" + self.label + ")"
+        logger.info(colored("DataFlow Info%s:" % label, 'cyan') + '\n'.join(msg))
         # reset again after print
         self.ds.reset_state()
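PrintData now accumulates its report in a list and emits it through a single logger.info call instead of one print per datapoint, which keeps the multi-line dump contiguous in the log. A stripped-down sketch of the pattern (items and describe are hypothetical stand-ins for the DataFlow and _analyze_input_data):

# Hypothetical reduction of the accumulate-then-join logging pattern.
def dump_info(items, describe):
    msg = [""]
    for i, item in enumerate(items):
        msg.append("datapoint %i with %i components consists of" % (i, len(item)))
        for k, entry in enumerate(item):
            msg.append(describe(entry, k))
    # one call keeps the whole block together in the output
    print('\n'.join(msg))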