Shashank Suhas / seminar-breakout / Commits / 7b91ba1d

Commit 7b91ba1d authored Aug 11, 2017 by Yuxin Wu
Parent: bee6798b

    write some more comments in dataflow

Showing 4 changed files with 65 additions and 52 deletions (+65 / -52):
    tensorpack/dataflow/base.py    +7  -7
    tensorpack/dataflow/common.py  +50 -39
    tensorpack/dataflow/format.py  +7  -6
    tensorpack/models/common.py    +1  -0
--- a/tensorpack/dataflow/base.py
+++ b/tensorpack/dataflow/base.py
@@ -13,9 +13,10 @@ __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow', 'DataFlowTerminated']

 class DataFlowTerminated(BaseException):
     """
-    An exception indicating that the DataFlow is unable to produce any more data:
-    calling :meth:`get_data` will not give a valid iterator any more.
-    In most DataFlow this will not be raised.
+    An exception indicating that the DataFlow is unable to produce any more
+    data, i.e. something wrong happened so that calling :meth:`get_data`
+    cannot give a valid iterator any more.
+    In most DataFlow this will never be raised.
     """
     pass
@@ -80,7 +81,9 @@ class RNGDataFlow(DataFlow):

 class ProxyDataFlow(DataFlow):
-    """ Base class for DataFlow that proxies another"""
+    """ Base class for DataFlow that proxies another.
+        Every method is proxied to ``self.ds`` unless override by subclass.
+    """

     def __init__(self, ds):
         """
@@ -90,9 +93,6 @@ class ProxyDataFlow(DataFlow):
         self.ds = ds

     def reset_state(self):
-        """
-        Reset state of the proxied DataFlow.
-        """
         self.ds.reset_state()

     def size(self):
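
The new ProxyDataFlow docstring promises that every method falls through to ``self.ds`` unless a subclass overrides it. A minimal sketch of a typical subclass, assuming tensorpack at this commit is importable (``PrintFirstComponent`` is a made-up name for illustration):

    from tensorpack.dataflow.base import ProxyDataFlow

    class PrintFirstComponent(ProxyDataFlow):
        # Only get_data() is overridden; size() and reset_state()
        # are inherited and proxied to self.ds automatically.
        def get_data(self):
            for dp in self.ds.get_data():
                print(dp[0])  # side effect on the first component
                yield dp
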
--- a/tensorpack/dataflow/common.py
+++ b/tensorpack/dataflow/common.py
@@ -32,7 +32,7 @@ class TestDataSpeed(ProxyDataFlow):
         self.test_size = size

     def get_data(self):
-        """ Will start testing at the beginning, then produce data normally. """
+        """ Will run testing at the beginning, then produce data normally. """
         self.start_test()
         for dp in self.ds.get_data():
             yield dp
@@ -62,22 +62,23 @@ class BatchData(ProxyDataFlow):
     Concat datapoints into batches.
     It produces datapoints of the same number of components as ``ds``, but
     each component has one new extra dimension of size ``batch_size``.
-    The new component can be a list of the original datapoints, or an ndarray
-    of the original datapoints.
+    A batch can be either a list of original components, or (by default)
+    a numpy array of original components.
     """

     def __init__(self, ds, batch_size, remainder=False, use_list=False):
         """
         Args:
-            ds (DataFlow): Its components must be either scalars or :class:`np.ndarray`.
-                Each component has to be of the same shape across datapoints.
+            ds (DataFlow): When ``use_list=False``, the components of ``ds``
+                must be either scalars or :class:`np.ndarray`, and
+                components has to have consistent shape across ``ds``.
             batch_size(int): batch size
-            remainder (bool):
-                whether to return the remaining data smaller than a batch_size.
-                If set True, it will possibly generates a data point of a smaller batch size.
-                Otherwise, all generated data are guranteed to have the same size.
-            use_list (bool): if True, it will run faster by producing a list
-                of datapoints instead of an ndarray of datapoints, avoiding an
-                extra copy.
+            remainder (bool): When the remaining datapoints in ``ds`` is not
+                enough to form a batch, whether or not to also produce the remaining
+                data as a smaller batch.
+                If set to False, all generated datapoints are guranteed to have the same batch size.
+            use_list (bool): if True, each component will contain a list
+                of datapoints instead of an numpy array of datapoints. This also avoids an
+                extra copy.
         """
         super(BatchData, self).__init__(ds)
         if not remainder:
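
A usage sketch of the batching semantics described above, assuming ``DataFromList`` from tensorpack.dataflow (component shapes are consistent, so the default numpy-array mode applies):

    import numpy as np
    from tensorpack.dataflow import BatchData, DataFromList

    # 5 datapoints, each with a single component of shape (2,)
    ds = DataFromList([[np.zeros(2)] for _ in range(5)], shuffle=False)
    ds = BatchData(ds, batch_size=2, remainder=True)
    ds.reset_state()
    for dp in ds.get_data():
        print(dp[0].shape)  # (2, 2), (2, 2), then (1, 2) for the remainder batch
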
@@ -152,6 +153,7 @@ class BatchDataByShape(BatchData):
     datapoints of different shape, and batches will be formed from those who
     have the same shape.

     Note:
+        It is implemented by a dict{shape -> datapoints}.
         Datapoints of uncommon shapes may never be enough to form a batch and
         never get generated.
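
The ``dict{shape -> datapoints}`` note can be illustrated standalone; this is a sketch of the grouping idea only, not the class's actual code:

    from collections import defaultdict

    def batches_by_shape(datapoints, batch_size, index=0):
        holder = defaultdict(list)  # shape -> buffered datapoints
        for dp in datapoints:
            shape = tuple(dp[index].shape)
            holder[shape].append(dp)
            if len(holder[shape]) == batch_size:
                yield holder.pop(shape)
        # buckets for uncommon shapes may never fill up,
        # so those datapoints are never generated
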
@@ -184,8 +186,8 @@ class BatchDataByShape(BatchData):

 class FixedSizeData(ProxyDataFlow):
-    """ Generate data from another DataFlow, but with a fixed size.
-        The iterator of the underlying DataFlow will be kept if not exhausted.
+    """ Generate data from another DataFlow, but with a fixed total count.
+        The iterator state of the underlying DataFlow will be kept if not exhausted.
     """

     def __init__(self, ds, size):
         """
@@ -220,19 +222,21 @@ class FixedSizeData(ProxyDataFlow):

 class MapData(ProxyDataFlow):
-    """ Apply a mapper/filter on the DataFlow"""
+    """
+    Apply a mapper/filter on the DataFlow.
+
+    Note:
+        1. Please make sure func doesn't modify the components
+           unless you're certain it's safe.
+        2. If you discard some datapoints, ``ds.size()`` will be incorrect.
+    """

     def __init__(self, ds, func):
         """
         Args:
             ds (DataFlow): input DataFlow
             func (datapoint -> datapoint | None): takes a datapoint and returns a new
-                datapoint. Return None to discard this data point.
-                Note that if you use the filter feature, ``ds.size()`` will be incorrect.
-
-        Note:
-            Please make sure func doesn't modify the components
-            unless you're certain it's safe.
+                datapoint. Return None to discard this datapoint.
         """
         super(MapData, self).__init__(ds)
         self.func = func
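
A sketch of both the mapper and the filter behavior, including the ``ds.size()`` caveat (``DataFromList`` assumed available, as above):

    from tensorpack.dataflow import DataFromList, MapData

    ds = DataFromList([[i] for i in range(6)], shuffle=False)
    # double each value, and discard odd values by returning None
    ds = MapData(ds, lambda dp: [dp[0] * 2] if dp[0] % 2 == 0 else None)
    ds.reset_state()
    print(list(ds.get_data()))  # [[0], [4], [8]]
    print(ds.size())            # still 6 -- incorrect, as the note warns
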
@@ -245,20 +249,23 @@ class MapData(ProxyDataFlow):

 class MapDataComponent(MapData):
-    """ Apply a mapper/filter on a datapoint component"""
+    """
+    Apply a mapper/filter on a datapoint component.
+
+    Note:
+        1. This dataflow itself doesn't modify the datapoints.
+           But please make sure func doesn't modify the components
+           unless you're certain it's safe.
+        2. If you discard some datapoints, ``ds.size()`` will be incorrect.
+    """

     def __init__(self, ds, func, index=0):
         """
         Args:
             ds (DataFlow): input DataFlow.
             func (TYPE -> TYPE|None): takes ``dp[index]``, returns a new value for ``dp[index]``.
                 return None to discard this datapoint.
-                Note that if you use the filter feature, ``ds.size()`` will be incorrect.
             index (int): index of the component.
-
-        Note:
-            This proxy itself doesn't modify the datapoints.
-            But please make sure func doesn't modify the components
-            unless you're certain it's safe.
         """
         def f(dp):
             r = func(dp[index])
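
The same idea restricted to one component; here only ``dp[0]`` is mapped and the other component passes through (a sketch under the same assumptions as above):

    from tensorpack.dataflow import DataFromList, MapDataComponent

    ds = DataFromList([[1, 'a'], [2, 'b'], [3, 'c']], shuffle=False)
    ds = MapDataComponent(ds, lambda x: x * x, index=0)
    ds.reset_state()
    print(list(ds.get_data()))  # [[1, 'a'], [4, 'b'], [9, 'c']]
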
@@ -272,7 +279,8 @@ class MapDataComponent(MapData):

 class RepeatedData(ProxyDataFlow):
     """ Take data points from another DataFlow and produce them until
-        it's exhausted for certain amount of times.
+        it's exhausted for certain amount of times. i.e.:
+        dp1, dp2, .... dpn, dp1, dp2, ....dpn
     """

     def __init__(self, ds, nr):
@@ -306,9 +314,9 @@ class RepeatedData(ProxyDataFlow):

 class RepeatedDataPoint(ProxyDataFlow):
     """ Take data points from another DataFlow and produce them a
-        certain number of times dp1, ..., dp1, dp2, ..., dp2, ...
+        certain number of times. i.e.:
+        dp1, dp1, ..., dp1, dp2, ..., dp2, ...
     """

     def __init__(self, ds, nr):
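
The two orderings spelled out by these docstrings, side by side (a sketch assuming the 2017-era signatures shown above, where ``nr`` is the repetition count):

    from tensorpack.dataflow import DataFromList, RepeatedData, RepeatedDataPoint

    ds = DataFromList([[1], [2], [3]], shuffle=False)
    whole = RepeatedData(ds, 2)       # whole stream repeated: 1 2 3 1 2 3
    each = RepeatedDataPoint(ds, 2)   # each point repeated:   1 1 2 2 3 3
    for df in (whole, each):
        df.reset_state()
        print([dp[0] for dp in df.get_data()])
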
@@ -408,7 +416,9 @@ class RandomMixData(RNGDataFlow):

 class ConcatData(DataFlow):
     """
-    Concatenate several dataflows. Produce datapoints from them one by one.
+    Concatenate several DataFlow.
+    Produce datapoints from each DataFlow and go to the next when one
+    DataFlow is exhausted.
     """

     def __init__(self, df_lists):
@@ -439,9 +449,9 @@ class JoinData(DataFlow):

     .. code-block:: none

-        dp1: [c1, c2]
-        dp2: [c3, c4]
-        join: [c1, c2, c3, c4]
+        df1 produces: [c1, c2]
+        df2 produces: [c3, c4]
+        joined: [c1, c2, c3, c4]
     """

     def __init__(self, df_lists):
@@ -495,9 +505,9 @@ def SelectComponent(ds, idxs):

     .. code-block:: none

-        dp: [c1, c2, c3]
+        original df produces: [c1, c2, c3]
         idxs: [2,1]
-        output dp: [c3, c2]
+        this df: [c3, c2]
     """
     return MapData(ds, lambda dp: [dp[i] for i in idxs])
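
Putting the two docstring examples above together in code, assuming both names are exported from tensorpack.dataflow:

    from tensorpack.dataflow import DataFromList, JoinData, SelectComponent

    df1 = DataFromList([['c1', 'c2']], shuffle=False)  # produces [c1, c2]
    df2 = DataFromList([['c3', 'c4']], shuffle=False)  # produces [c3, c4]

    joined = JoinData([df1, df2])             # produces [c1, c2, c3, c4]
    picked = SelectComponent(joined, [2, 1])  # produces [c3, c2]
    picked.reset_state()
    print(next(picked.get_data()))            # ['c3', 'c2']
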
@@ -561,7 +571,8 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):

 class CacheData(ProxyDataFlow):
     """
-    Cache a dataflow completely in memory.
+    Cache the first pass of a DataFlow completely in memory,
+    and produce from the cache thereafter.
     """

     def __init__(self, ds, shuffle=False):
         """
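
A sketch of the cache-on-first-pass behavior:

    from tensorpack.dataflow import CacheData, DataFromList

    ds = CacheData(DataFromList([[1], [2], [3]], shuffle=False))
    ds.reset_state()
    list(ds.get_data())  # first pass reads the underlying DataFlow and fills the cache
    list(ds.get_data())  # subsequent passes are served from memory
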
--- a/tensorpack/dataflow/format.py
+++ b/tensorpack/dataflow/format.py
@@ -29,9 +29,9 @@ class HDF5Data(RNGDataFlow):
     Zip data from different paths in an HDF5 file.

     Warning:
-        The current implementation will load all data into memory.
+        The current implementation will load all data into memory. (TODO)
     """
-    # TODO lazy load
+    # TODO

     def __init__(self, filename, data_paths, shuffle=True):
         """
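
Hypothetical usage (the file name and dataset paths are made up; the HDF5 datasets are assumed to have equal length so they can be zipped):

    from tensorpack.dataflow import HDF5Data

    ds = HDF5Data('train.h5', data_paths=['images', 'labels'], shuffle=True)
    ds.reset_state()
    for image, label in ds.get_data():  # one component per data_path
        break
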
@@ -61,7 +61,8 @@ class HDF5Data(RNGDataFlow):

 class LMDBData(RNGDataFlow):
-    """ Read a LMDB database and produce (k,v) pairs """
+    """ Read a LMDB database and produce (k,v) raw string pairs.
+    """

     def __init__(self, lmdb_path, shuffle=True, keys=None):
         """
         Args:
@@ -79,7 +80,7 @@ class LMDBData(RNGDataFlow):
         self._lmdb_path = lmdb_path
         self._shuffle = shuffle

-        self.open_lmdb()
+        self._open_lmdb()
         self._size = self._txn.stat()['entries']
         self._set_keys(keys)
         logger.info("Found {} entries in {}".format(self._size, self._lmdb_path))
@@ -113,7 +114,7 @@ class LMDBData(RNGDataFlow):
         else:
             self.keys = keys

-    def open_lmdb(self):
+    def _open_lmdb(self):
         self._lmdb = lmdb.open(self._lmdb_path,
                                subdir=os.path.isdir(self._lmdb_path),
                                readonly=True, lock=False, readahead=True,
@@ -123,7 +124,7 @@ class LMDBData(RNGDataFlow):
     def reset_state(self):
         self._lmdb.close()
         super(LMDBData, self).reset_state()
-        self.open_lmdb()
+        self._open_lmdb()

     def size(self):
         return self._size
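
Since LMDBData now advertises raw string (k, v) pairs, a decoder is typically chained on top; a sketch with a made-up database path:

    from tensorpack.dataflow import LMDBData

    ds = LMDBData('/path/to/db.lmdb', shuffle=False)  # hypothetical path
    ds.reset_state()
    for k, v in ds.get_data():  # both k and v are raw strings/bytes
        print(k)
        break
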
--- a/tensorpack/models/common.py
+++ b/tensorpack/models/common.py
@@ -131,6 +131,7 @@ def layer_register(
         if name is not None:  # use scope
             with tf.variable_scope(name) as scope:
+                # this name is only used to surpress logging, doesn't hurt to do some heuristics
                 scope_name = re.sub('tower[0-9]+/', '', scope.name)
                 do_log_shape = log_shape and scope_name not in _LAYER_LOGGED
                 if do_log_shape:
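
What the new comment refers to, in isolation: the heuristic strips a ``towerN/`` prefix so a layer's shape is logged once rather than once per replica tower:

    import re

    for name in ['tower0/conv1', 'tower12/res2/conv2', 'conv3']:
        print(re.sub('tower[0-9]+/', '', name))
    # -> conv1, res2/conv2, conv3
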