Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
d04661e3
Commit
d04661e3
authored
Apr 17, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix prefetch bug
parent
b81c2263
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
7 additions
and
26 deletions
+7
-26
examples/cifar10_convnet.py
examples/cifar10_convnet.py
+1
-6
examples/cifar10_resnet.py
examples/cifar10_resnet.py
+4
-7
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+2
-13
No files found.
examples/cifar10_convnet.py
View file @
d04661e3
...
@@ -16,17 +16,12 @@ from tensorpack.tfutils import *
...
@@ -16,17 +16,12 @@ from tensorpack.tfutils import *
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack.tfutils.summary
import
*
from
tensorpack.tfutils.summary
import
*
from
tensorpack.dataflow
import
*
from
tensorpack.dataflow
import
*
from
tensorpack.dataflow
import
imgaug
"""
"""
A small cifar10 convnet model.
A small cifar10 convnet model.
90
%
validation accuracy after 40k step.
90
%
validation accuracy after 40k step.
"""
"""
BATCH_SIZE
=
128
MIN_AFTER_DEQUEUE
=
int
(
50000
*
0.4
)
CAPACITY
=
MIN_AFTER_DEQUEUE
+
3
*
BATCH_SIZE
class
Model
(
ModelDesc
):
class
Model
(
ModelDesc
):
def
_get_input_vars
(
self
):
def
_get_input_vars
(
self
):
return
[
InputVar
(
tf
.
float32
,
[
None
,
30
,
30
,
3
],
'input'
),
return
[
InputVar
(
tf
.
float32
,
[
None
,
30
,
30
,
3
],
'input'
),
...
@@ -134,7 +129,7 @@ def get_config():
...
@@ -134,7 +129,7 @@ def get_config():
session_config
=
sess_config
,
session_config
=
sess_config
,
model
=
Model
(),
model
=
Model
(),
step_per_epoch
=
step_per_epoch
,
step_per_epoch
=
step_per_epoch
,
max_epoch
=
200
,
max_epoch
=
3
,
)
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
examples/cifar10_resnet.py
View file @
d04661e3
...
@@ -16,7 +16,6 @@ from tensorpack.tfutils import *
...
@@ -16,7 +16,6 @@ from tensorpack.tfutils import *
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack.tfutils.summary
import
*
from
tensorpack.tfutils.summary
import
*
from
tensorpack.dataflow
import
*
from
tensorpack.dataflow
import
*
from
tensorpack.dataflow
import
imgaug
"""
"""
CIFAR10-resnet example.
CIFAR10-resnet example.
...
@@ -45,7 +44,7 @@ class Model(ModelDesc):
...
@@ -45,7 +44,7 @@ class Model(ModelDesc):
def
_get_cost
(
self
,
input_vars
,
is_training
):
def
_get_cost
(
self
,
input_vars
,
is_training
):
image
,
label
=
input_vars
image
,
label
=
input_vars
image
=
image
/
255.0
image
=
image
/
128.0
-
1
def
conv
(
name
,
l
,
channel
,
stride
):
def
conv
(
name
,
l
,
channel
,
stride
):
return
Conv2D
(
name
,
l
,
channel
,
3
,
stride
=
stride
,
return
Conv2D
(
name
,
l
,
channel
,
3
,
stride
=
stride
,
...
@@ -117,10 +116,10 @@ class Model(ModelDesc):
...
@@ -117,10 +116,10 @@ class Model(ModelDesc):
# weight decay on all W of fc layers
# weight decay on all W of fc layers
wd_w
=
tf
.
train
.
exponential_decay
(
0.0002
,
get_global_step_var
(),
wd_w
=
tf
.
train
.
exponential_decay
(
0.0002
,
get_global_step_var
(),
480000
,
0.2
,
True
)
480000
,
0.2
,
True
)
wd_cost
=
wd_w
*
regularize_cost
(
'.*/W'
,
tf
.
nn
.
l2_loss
)
wd_cost
=
tf
.
mul
(
wd_w
,
regularize_cost
(
'.*/W'
,
tf
.
nn
.
l2_loss
),
name
=
'wd_cost'
)
tf
.
add_to_collection
(
MOVING_SUMMARY_VARS_KEY
,
wd_cost
)
tf
.
add_to_collection
(
MOVING_SUMMARY_VARS_KEY
,
wd_cost
)
add_param_summary
([(
'.*/W'
,
[
'histogram'
,
'sparsity'
])])
# monitor W
add_param_summary
([(
'.*/W'
,
[
'histogram'
])])
# monitor W
return
tf
.
add_n
([
cost
,
wd_cost
],
name
=
'cost'
)
return
tf
.
add_n
([
cost
,
wd_cost
],
name
=
'cost'
)
def
get_data
(
train_or_test
):
def
get_data
(
train_or_test
):
...
@@ -146,8 +145,6 @@ def get_data(train_or_test):
...
@@ -146,8 +145,6 @@ def get_data(train_or_test):
ds
=
PrefetchData
(
ds
,
3
,
2
)
ds
=
PrefetchData
(
ds
,
3
,
2
)
return
ds
return
ds
def
get_config
():
def
get_config
():
# prepare dataset
# prepare dataset
dataset_train
=
get_data
(
'train'
)
dataset_train
=
get_data
(
'train'
)
...
@@ -170,7 +167,7 @@ def get_config():
...
@@ -170,7 +167,7 @@ def get_config():
[(
1
,
0.1
),
(
82
,
0.01
),
(
123
,
0.001
),
(
300
,
0.0002
)])
[(
1
,
0.1
),
(
82
,
0.01
),
(
123
,
0.001
),
(
300
,
0.0002
)])
]),
]),
session_config
=
sess_config
,
session_config
=
sess_config
,
model
=
Model
(
n
=
18
),
model
=
Model
(
n
=
30
),
step_per_epoch
=
step_per_epoch
,
step_per_epoch
=
step_per_epoch
,
max_epoch
=
500
,
max_epoch
=
500
,
)
)
...
...
tensorpack/dataflow/prefetch.py
View file @
d04661e3
...
@@ -9,9 +9,6 @@ from ..utils.concurrency import ensure_procs_terminate
...
@@ -9,9 +9,6 @@ from ..utils.concurrency import ensure_procs_terminate
__all__
=
[
'PrefetchData'
]
__all__
=
[
'PrefetchData'
]
class
Sentinel
:
pass
class
PrefetchProcess
(
multiprocessing
.
Process
):
class
PrefetchProcess
(
multiprocessing
.
Process
):
def
__init__
(
self
,
ds
,
queue
):
def
__init__
(
self
,
ds
,
queue
):
"""
"""
...
@@ -24,11 +21,9 @@ class PrefetchProcess(multiprocessing.Process):
...
@@ -24,11 +21,9 @@ class PrefetchProcess(multiprocessing.Process):
def
run
(
self
):
def
run
(
self
):
self
.
ds
.
reset_state
()
self
.
ds
.
reset_state
()
try
:
while
True
:
for
dp
in
self
.
ds
.
get_data
():
for
dp
in
self
.
ds
.
get_data
():
self
.
queue
.
put
(
dp
)
self
.
queue
.
put
(
dp
)
finally
:
self
.
queue
.
put
(
Sentinel
())
class
PrefetchData
(
ProxyDataFlow
):
class
PrefetchData
(
ProxyDataFlow
):
"""
"""
...
@@ -52,17 +47,11 @@ class PrefetchData(ProxyDataFlow):
...
@@ -52,17 +47,11 @@ class PrefetchData(ProxyDataFlow):
x
.
start
()
x
.
start
()
def
get_data
(
self
):
def
get_data
(
self
):
end_cnt
=
0
tot_cnt
=
0
tot_cnt
=
0
while
True
:
while
True
:
dp
=
self
.
queue
.
get
()
dp
=
self
.
queue
.
get
()
if
isinstance
(
dp
,
Sentinel
):
end_cnt
+=
1
if
end_cnt
==
self
.
nr_proc
:
break
continue
tot_cnt
+=
1
yield
dp
yield
dp
tot_cnt
+=
1
if
tot_cnt
==
self
.
_size
:
if
tot_cnt
==
self
.
_size
:
break
break
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment