seminar-breakout (Shashank Suhas)

Commit 0fa11e65, authored Oct 06, 2016 by Yuxin Wu
Commit message: batchdatabyshape
Parent commit: c51df958
Showing 4 changed files with 37 additions and 10 deletions:
    examples/HED/hed.py                     +4   -6
    examples/Inception/inceptionv3.py       +1   -0
    tensorpack/dataflow/common.py           +31  -3
    tensorpack/dataflow/dataset/ilsvrc.py   +1   -1
examples/HED/hed.py

@@ -43,8 +43,6 @@ Usage:
     ../../scripts/plot-point.py --legend 1,2,3,4,5,final --decay 0.8
 """
 
-BATCH_SIZE = 1
-
 class Model(ModelDesc):
     def __init__(self, is_training=True):
         self.isTrain = is_training
@@ -163,13 +161,13 @@ def get_data(name):
         imgaug.GaussianNoise(),
     ]
     ds = AugmentImageComponent(ds, augmentors)
-    ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
-    #if isTrain:
-        #ds = PrefetchDataZMQ(ds, 3)
+    ds = BatchDataByShape(ds, 8, idx=0)
+    if isTrain:
+        ds = PrefetchDataZMQ(ds, 1)
     return ds
 
 def view_data():
-    ds = get_data('train')
+    ds = RepeatedData(get_data('train'), -1)
     ds.reset_state()
     for ims, edgemaps in ds.get_data():
         for im, edgemap in zip(ims, edgemaps):
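The swap from BatchData to BatchDataByShape matters here because the training images in this example do not share a single size (e.g. landscape vs. portrait), so with BATCH_SIZE = 1 gone, a plain BatchData of size 8 would have to stack arrays of unequal shapes, which cannot be combined into one dense array. BatchDataByShape(ds, 8, idx=0) instead accumulates datapoints whose first component (the image) has an identical shape and emits a batch of 8 only from within one shape group. A rough plain-numpy illustration of that constraint (not taken from the repository; the shapes are made up):

import numpy as np
from collections import defaultdict

imgs = [np.zeros((321, 481, 3)), np.zeros((481, 321, 3)), np.zeros((321, 481, 3))]
# np.stack(imgs)             # ValueError: all input arrays must have the same shape

groups = defaultdict(list)   # what BatchDataByShape does internally, keyed by shape
for im in imgs:
    groups[im.shape].append(im)
batches = [np.stack(v) for v in groups.values()]
print([b.shape for b in batches])   # [(2, 321, 481, 3), (1, 481, 321, 3)]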
examples/Inception/inceptionv3.py

@@ -24,6 +24,7 @@ with much much fewer lines of code.
 It reaches 74.5% single-crop validation accuracy, slightly better than the official code,
 and has the same running speed as well.
+The hyperparameters here are for 8 GPUs, so the effective batch size is 8*64 = 512.
 With 8 TitanX it runs about 0.45 it/s.
 """
 BATCH_SIZE = 64
tensorpack/dataflow/common.py

@@ -5,7 +5,7 @@
 from __future__ import division
 import copy
 import numpy as np
-from collections import deque
+from collections import deque, defaultdict
 from six.moves import range, map
 from .base import DataFlow, ProxyDataFlow, RNGDataFlow
 from ..utils import *
@@ -13,7 +13,7 @@ from ..utils import *
 __all__ = ['BatchData', 'FixedSizeData', 'MapData', 'RepeatedData',
            'MapDataComponent', 'RandomChooseData', 'RandomMixData',
            'JoinData', 'ConcatData', 'SelectComponent',
-           'LocallyShuffleData', 'TestDataSpeed']
+           'LocallyShuffleData', 'TestDataSpeed', 'BatchDataByShape']
 
 class TestDataSpeed(ProxyDataFlow):
     def __init__(self, ds, size=1000):
@@ -70,7 +70,7 @@ class BatchData(ProxyDataFlow):
             holder.append(data)
             if len(holder) == self.batch_size:
                 yield BatchData._aggregate_batch(holder)
-                holder = []
+                del holder[:]
         if self.remainder and len(holder) > 0:
             yield BatchData._aggregate_batch(holder)
@@ -97,6 +97,34 @@ class BatchData(ProxyDataFlow):
             IP.embed(config=IP.terminal.ipapp.load_default_config())
         return result
 
+class BatchDataByShape(BatchData):
+    def __init__(self, ds, batch_size, idx):
+        """ Group datapoint of the same shape together to batches
+        :param ds: a DataFlow instance. Its component must be either a scalar or a numpy array
+        :param idx: dp[idx] will be used to group datapoints. Other component
+            in dp are assumed to have the same shape.
+        """
+        super(BatchDataByShape, self).__init__(ds, batch_size, remainder=False)
+        self.idx = idx
+
+    def size(self):
+        raise NotImplementedError()
+
+    def reset_state(self):
+        super(BatchDataByShape, self).reset_state()
+        self.holder = defaultdict(list)
+
+    def get_data(self):
+        for dp in self.ds.get_data():
+            shp = dp[self.idx].shape
+            print(shp, len(self.holder))
+            holder = self.holder[shp]
+            holder.append(dp)
+            if len(holder) == self.batch_size:
+                yield BatchData._aggregate_batch(holder)
+                del holder[:]
+
 class FixedSizeData(ProxyDataFlow):
     """ Generate data from another DataFlow, but with a fixed epoch size.
         The state of the underlying DataFlow is maintained among each epoch.
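BatchDataByShape reuses BatchData._aggregate_batch but keeps one pending list per observed shape in a defaultdict: a batch is emitted only once batch_size datapoints with an identical dp[idx].shape have accumulated, and because remainder=False and size() is unimplemented, leftover datapoints at the end of an epoch are silently dropped. It clears each per-shape list with del holder[:] because holder aliases an entry of self.holder, so rebinding it would leave stale datapoints in the dict; the same in-place clearing replaces holder = [] in BatchData, apparently for consistency. A minimal usage sketch, not part of the commit, assuming the 2016-era DataFlow API shown above (VariableShapeFlow is a made-up toy source; note the debug print left in get_data will also echo each datapoint's shape):

import numpy as np
from tensorpack.dataflow import DataFlow, BatchDataByShape

class VariableShapeFlow(DataFlow):
    """Toy source: yields [image, label] datapoints whose image shape alternates."""
    def size(self):
        return 100

    def get_data(self):
        shapes = [(32, 32, 3), (64, 48, 3)]
        for i in range(100):
            yield [np.zeros(shapes[i % 2], dtype='float32'), i % 10]

ds = BatchDataByShape(VariableShapeFlow(), batch_size=8, idx=0)
ds.reset_state()
for dp in ds.get_data():
    # every image batch is shape-homogeneous: (8, 32, 32, 3) or (8, 64, 48, 3), never mixed
    print(dp[0].shape)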
tensorpack/dataflow/dataset/ilsvrc.py

@@ -100,7 +100,7 @@ class ILSVRC12(RNGDataFlow):
             If is 'original' then keep the original decompressed dir with list
             of image files (as below). If equals to 'train', use the `train/` dir
             structure with class name as subdirectories.
-        :param include_bb: Include the bounding box. Useful in training.
+        :param include_bb: Include the bounding box. Maybe useful in training.
         Dir should have the following structure: