Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
0e5c83b5
Commit
0e5c83b5
authored
Dec 03, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
BatchData supports dict (fix #768)
parent
98eb3db5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
47 additions
and
34 deletions
+47
-34
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+47
-34
No files found.
tensorpack/dataflow/common.py
View file @
0e5c83b5
...
@@ -76,7 +76,8 @@ class BatchData(ProxyDataFlow):
...
@@ -76,7 +76,8 @@ class BatchData(ProxyDataFlow):
def
__init__
(
self
,
ds
,
batch_size
,
remainder
=
False
,
use_list
=
False
):
def
__init__
(
self
,
ds
,
batch_size
,
remainder
=
False
,
use_list
=
False
):
"""
"""
Args:
Args:
ds (DataFlow): When ``use_list=False``, the components of ``ds``
ds (DataFlow): A dataflow that produces either list or dict.
When ``use_list=False``, the components of ``ds``
must be either scalars or :class:`np.ndarray`, and have to be consistent in shapes.
must be either scalars or :class:`np.ndarray`, and have to be consistent in shapes.
batch_size(int): batch size
batch_size(int): batch size
remainder (bool): When the remaining datapoints in ``ds`` is not
remainder (bool): When the remaining datapoints in ``ds`` is not
...
@@ -119,42 +120,54 @@ class BatchData(ProxyDataFlow):
...
@@ -119,42 +120,54 @@ class BatchData(ProxyDataFlow):
if
self
.
remainder
and
len
(
holder
)
>
0
:
if
self
.
remainder
and
len
(
holder
)
>
0
:
yield
BatchData
.
_aggregate_batch
(
holder
,
self
.
use_list
)
yield
BatchData
.
_aggregate_batch
(
holder
,
self
.
use_list
)
@
staticmethod
def
_batch_numpy
(
data_list
):
data
=
data_list
[
0
]
if
isinstance
(
data
,
six
.
integer_types
):
dtype
=
'int32'
elif
type
(
data
)
==
bool
:
dtype
=
'bool'
elif
type
(
data
)
==
float
:
dtype
=
'float32'
elif
isinstance
(
data
,
(
six
.
binary_type
,
six
.
text_type
)):
dtype
=
'str'
else
:
try
:
dtype
=
data
.
dtype
except
AttributeError
:
raise
TypeError
(
"Unsupported type to batch: {}"
.
format
(
type
(
data
)))
try
:
return
np
.
asarray
(
data_list
,
dtype
=
dtype
)
except
Exception
as
e
:
# noqa
logger
.
exception
(
"Cannot batch data. Perhaps they are of inconsistent shape?"
)
if
isinstance
(
data
,
np
.
ndarray
):
s
=
pprint
.
pformat
([
x
.
shape
for
x
in
data_list
])
logger
.
error
(
"Shape of all arrays to be batched: "
+
s
)
try
:
# open an ipython shell if possible
import
IPython
as
IP
;
IP
.
embed
()
# noqa
except
ImportError
:
pass
@
staticmethod
@
staticmethod
def
_aggregate_batch
(
data_holder
,
use_list
=
False
):
def
_aggregate_batch
(
data_holder
,
use_list
=
False
):
size
=
len
(
data_holder
[
0
])
first_dp
=
data_holder
[
0
]
result
=
[]
if
isinstance
(
first_dp
,
(
list
,
tuple
)):
for
k
in
range
(
size
):
result
=
[]
if
use_list
:
for
k
in
range
(
len
(
first_dp
)):
result
.
append
(
data_list
=
[
x
[
k
]
for
x
in
data_holder
]
[
x
[
k
]
for
x
in
data_holder
])
if
use_list
:
else
:
result
.
append
(
data_list
)
data
=
data_holder
[
0
][
k
]
if
isinstance
(
data
,
six
.
integer_types
):
dtype
=
'int32'
elif
type
(
data
)
==
bool
:
dtype
=
'bool'
elif
type
(
data
)
==
float
:
dtype
=
'float32'
elif
isinstance
(
data
,
(
six
.
binary_type
,
six
.
text_type
)):
dtype
=
'str'
else
:
else
:
try
:
result
.
append
(
BatchData
.
_batch_numpy
(
data_list
))
dtype
=
data
.
dtype
elif
isinstance
(
first_dp
,
dict
):
except
AttributeError
:
result
=
[]
raise
TypeError
(
"Unsupported type to batch: {}"
.
format
(
type
(
data
)))
for
key
in
first_dp
.
keys
():
try
:
data_list
=
[
x
[
k
]
for
x
in
data_holder
]
result
.
append
(
if
use_list
:
np
.
asarray
([
x
[
k
]
for
x
in
data_holder
],
dtype
=
dtype
))
result
[
key
]
=
data_list
except
Exception
as
e
:
# noqa
else
:
logger
.
exception
(
"Cannot batch data. Perhaps they are of inconsistent shape?"
)
result
[
key
]
=
BatchData
.
_batch_numpy
(
data_list
)
if
isinstance
(
data
,
np
.
ndarray
):
s
=
pprint
.
pformat
([
x
[
k
]
.
shape
for
x
in
data_holder
])
logger
.
error
(
"Shape of all arrays to be batched: "
+
s
)
try
:
# open an ipython shell if possible
import
IPython
as
IP
;
IP
.
embed
()
# noqa
except
ImportError
:
pass
return
result
return
result
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment