Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
27841032
Commit
27841032
authored
Mar 31, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add TFRecordData (#174)
parent
f3d290cc
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
51 additions
and
9 deletions
+51
-9
tensorpack/dataflow/dftools.py
tensorpack/dataflow/dftools.py
+6
-7
tensorpack/dataflow/format.py
tensorpack/dataflow/format.py
+26
-2
tensorpack/dataflow/raw.py
tensorpack/dataflow/raw.py
+19
-0
No files found.
tensorpack/dataflow/dftools.py
View file @
27841032
...
@@ -129,13 +129,6 @@ def dump_dataflow_to_lmdb(df, lmdb_path, write_frequency=5000):
...
@@ -129,13 +129,6 @@ def dump_dataflow_to_lmdb(df, lmdb_path, write_frequency=5000):
db
.
close
()
db
.
close
()
from
..utils.develop
import
create_dummy_func
# noqa
try
:
import
lmdb
except
ImportError
:
dump_dataflow_to_lmdb
=
create_dummy_func
(
'dump_dataflow_to_lmdb'
,
'lmdb'
)
# noqa
def
dump_dataflow_to_tfrecord
(
df
,
path
):
def
dump_dataflow_to_tfrecord
(
df
,
path
):
"""
"""
Dump all datapoints of a Dataflow to a TensorFlow TFRecord file,
Dump all datapoints of a Dataflow to a TensorFlow TFRecord file,
...
@@ -151,6 +144,12 @@ def dump_dataflow_to_tfrecord(df, path):
...
@@ -151,6 +144,12 @@ def dump_dataflow_to_tfrecord(df, path):
writer
.
write
(
dumps
(
dp
))
writer
.
write
(
dumps
(
dp
))
from
..utils.develop
import
create_dummy_func
# noqa
try
:
import
lmdb
except
ImportError
:
dump_dataflow_to_lmdb
=
create_dummy_func
(
'dump_dataflow_to_lmdb'
,
'lmdb'
)
# noqa
try
:
try
:
import
tensorflow
as
tf
import
tensorflow
as
tf
except
ImportError
:
except
ImportError
:
...
...
tensorpack/dataflow/format.py
View file @
27841032
...
@@ -12,11 +12,11 @@ from ..utils.timer import timed_operation
...
@@ -12,11 +12,11 @@ from ..utils.timer import timed_operation
from
..utils.loadcaffe
import
get_caffe_pb
from
..utils.loadcaffe
import
get_caffe_pb
from
..utils.serialize
import
loads
from
..utils.serialize
import
loads
from
..utils.argtools
import
log_once
from
..utils.argtools
import
log_once
from
.base
import
RNGDataFlow
from
.base
import
RNGDataFlow
,
DataFlow
from
.common
import
MapData
from
.common
import
MapData
__all__
=
[
'HDF5Data'
,
'LMDBData'
,
'LMDBDataDecoder'
,
'LMDBDataPoint'
,
__all__
=
[
'HDF5Data'
,
'LMDBData'
,
'LMDBDataDecoder'
,
'LMDBDataPoint'
,
'CaffeLMDB'
,
'SVMLightData'
]
'CaffeLMDB'
,
'SVMLightData'
,
'TFRecordData'
]
"""
"""
Adapters for different data format.
Adapters for different data format.
...
@@ -228,6 +228,25 @@ class SVMLightData(RNGDataFlow):
...
@@ -228,6 +228,25 @@ class SVMLightData(RNGDataFlow):
yield
[
self
.
X
[
id
,
:],
self
.
y
[
id
]]
yield
[
self
.
X
[
id
,
:],
self
.
y
[
id
]]
class TFRecordData(DataFlow):
    """
    Produce datapoints from a TFRecord file, assuming each record is
    serialized by :func:`serialize.dumps`.
    This class works with :func:`dftools.dump_dataflow_to_tfrecord`.
    """
    def __init__(self, path, size=None):
        """
        Args:
            path (str): path to the TFRecord file.
            size (int): known number of records in the file, returned by
                :meth:`size`. If None, ``size()`` falls back to the base
                class behavior.
        """
        # Store the path rather than the iterator itself:
        # tf.python_io.tf_record_iterator returns a one-shot generator, so
        # creating it once in __init__ would leave the dataflow exhausted
        # after the first full pass. Building a fresh iterator inside each
        # get_data() call makes the dataflow re-iterable across epochs.
        self._path = path
        self._size = size

    def size(self):
        # NOTE(review): a user-provided size of 0 falls through to the base
        # class here because of truthiness — presumably not a supported case.
        if self._size:
            return self._size
        return super(TFRecordData, self).size()

    def get_data(self):
        # Fresh iterator per call; each record is deserialized with the
        # project's `loads` (counterpart of dump_dataflow_to_tfrecord).
        gen = tf.python_io.tf_record_iterator(self._path)
        for dp in gen:
            yield loads(dp)
from
..utils.develop
import
create_dummy_class
# noqa
from
..utils.develop
import
create_dummy_class
# noqa
try
:
try
:
import
h5py
import
h5py
...
@@ -244,3 +263,8 @@ try:
...
@@ -244,3 +263,8 @@ try:
import
sklearn.datasets
import
sklearn.datasets
except
ImportError
:
except
ImportError
:
SVMLightData
=
create_dummy_class
(
'SVMLightData'
,
'sklearn'
)
# noqa
SVMLightData
=
create_dummy_class
(
'SVMLightData'
,
'sklearn'
)
# noqa
# TFRecordData needs TensorFlow at runtime; when TF is not installed,
# replace the class with a dummy that raises an informative error on use,
# so importing this module still succeeds.
try:
    import tensorflow as tf
except ImportError:
    TFRecordData = create_dummy_class('TFRecordData', 'tensorflow')  # noqa
tensorpack/dataflow/raw.py
View file @
27841032
...
@@ -81,3 +81,22 @@ class DataFromList(RNGDataFlow):
...
@@ -81,3 +81,22 @@ class DataFromList(RNGDataFlow):
self
.
rng
.
shuffle
(
idxs
)
self
.
rng
.
shuffle
(
idxs
)
for
k
in
idxs
:
for
k
in
idxs
:
yield
self
.
lst
[
k
]
yield
self
.
lst
[
k
]
class DataFromGenerator(DataFlow):
    """
    Wrap a generator into the DataFlow interface.
    """
    def __init__(self, gen, size=None):
        # Keep the (one-shot) generator and the optional declared size.
        self._gen = gen
        self._size = size

    def size(self):
        # Fall back to the base class unless a truthy size was provided.
        if not self._size:
            return super(DataFromGenerator, self).size()
        return self._size

    def get_data(self):
        # ``yield from self._gen`` would be equivalent (py3-only syntax).
        for datapoint in self._gen:
            yield datapoint
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment