Commit 7b91ba1d
authored Aug 11, 2017 by Yuxin Wu

    write some more comments in dataflow

parent bee6798b

Showing 4 changed files with 65 additions and 52 deletions (+65, -52):

    tensorpack/dataflow/base.py     +7  -7
    tensorpack/dataflow/common.py   +50 -39
    tensorpack/dataflow/format.py   +7  -6
    tensorpack/models/common.py     +1  -0
tensorpack/dataflow/base.py

@@ -13,9 +13,10 @@ __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow', 'DataFlowTerminated']

 class DataFlowTerminated(BaseException):
     """
-    An exception indicating that the DataFlow is unable to produce any more data:
-    calling :meth:`get_data` will not give a valid iterator any more.
-    In most DataFlow this will not be raised.
+    An exception indicating that the DataFlow is unable to produce any more
+    data, i.e. something wrong happened so that calling :meth:`get_data`
+    cannot give a valid iterator any more.
+    In most DataFlow this will never be raised.
     """
     pass

@@ -80,7 +81,9 @@ class RNGDataFlow(DataFlow):

 class ProxyDataFlow(DataFlow):
-    """ Base class for DataFlow that proxies another"""
+    """ Base class for DataFlow that proxies another.
+
+    Every method is proxied to ``self.ds`` unless override by subclass.
+    """

     def __init__(self, ds):
         """

@@ -90,9 +93,6 @@ class ProxyDataFlow(DataFlow):
         self.ds = ds

     def reset_state(self):
-        """
-        Reset state of the proxied DataFlow.
-        """
         self.ds.reset_state()

     def size(self):
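The new ProxyDataFlow docstring is easy to demonstrate with a minimal sketch
(not part of this commit; ToyFlow and TakeEveryOther are hypothetical names):
a subclass overrides only get_data(), while size() and reset_state() fall
through to the wrapped ``self.ds``.

    from tensorpack.dataflow import DataFlow, ProxyDataFlow

    class ToyFlow(DataFlow):
        """A toy DataFlow producing the datapoints [0], [1], ..., [9]."""
        def get_data(self):
            for i in range(10):
                yield [i]

        def size(self):
            return 10

    class TakeEveryOther(ProxyDataFlow):
        """Overrides get_data() only; everything else stays proxied."""
        def get_data(self):
            for idx, dp in enumerate(self.ds.get_data()):
                if idx % 2 == 0:
                    yield dp

    ds = TakeEveryOther(ToyFlow())
    ds.reset_state()                        # proxied to ToyFlow.reset_state()
    print([dp[0] for dp in ds.get_data()])  # [0, 2, 4, 6, 8]
    print(ds.size())                        # 10: proxied, hence wrong after filtering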
tensorpack/dataflow/common.py

@@ -32,7 +32,7 @@ class TestDataSpeed(ProxyDataFlow):
         self.test_size = size

     def get_data(self):
-        """ Will start testing at the beginning, then produce data normally. """
+        """ Will run testing at the beginning, then produce data normally. """
         self.start_test()
         for dp in self.ds.get_data():
             yield dp
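A hypothetical usage sketch for the docstring above (assuming DataFromList
from the same tensorpack.dataflow package as a convenient source): get_data()
first times ``size`` datapoints via start_test(), then produces data normally.

    from tensorpack.dataflow import DataFromList, TestDataSpeed

    ds = TestDataSpeed(DataFromList([[i] for i in range(100)], shuffle=False), 100)
    ds.reset_state()
    for dp in ds.get_data():   # benchmarks 100 datapoints, then yields them
        pass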
@@ -62,22 +62,23 @@ class BatchData(ProxyDataFlow):
     Concat datapoints into batches.
     It produces datapoints of the same number of components as ``ds``, but
     each component has one new extra dimension of size ``batch_size``.
-    The new component can be a list of the original datapoints, or an ndarray
-    of the original datapoints.
+    A batch can be either a list of original components, or (by default)
+    a numpy array of original components.
     """

     def __init__(self, ds, batch_size, remainder=False, use_list=False):
         """
         Args:
-            ds (DataFlow): Its components must be either scalars or :class:`np.ndarray`.
-                Each component has to be of the same shape across datapoints.
+            ds (DataFlow): When ``use_list=False``, the components of ``ds``
+                must be either scalars or :class:`np.ndarray`, and
+                components has to have consistent shape across ``ds``.
             batch_size(int): batch size
-            remainder (bool): whether to return the remaining data smaller than a batch_size.
-                If set True, it will possibly generates a data point of a smaller batch size.
-                Otherwise, all generated data are guranteed to have the same size.
-            use_list (bool): if True, it will run faster by producing a list
-                of datapoints instead of an ndarray of datapoints, avoiding an
-                extra copy.
+            remainder (bool): When the remaining datapoints in ``ds`` is not
+                enough to form a batch, whether or not to also produce the remaining
+                data as a smaller batch.
+                If set to False, all generated datapoints are guranteed to have the same batch size.
+            use_list (bool): if True, each component will contain a list
+                of datapoints instead of an numpy array of datapoints. This also avoids an
+                extra copy.
         """
         super(BatchData, self).__init__(ds)
         if not remainder:
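A minimal sketch of the clarified BatchData semantics (toy data; assumes
DataFromList is available): components keep consistent shapes, and each
component gains a leading dimension of size ``batch_size``.

    import numpy as np
    from tensorpack.dataflow import BatchData, DataFromList

    datapoints = [[np.zeros((4, 4), dtype='float32'), i] for i in range(10)]
    ds = BatchData(DataFromList(datapoints, shuffle=False), 4, remainder=True)
    ds.reset_state()
    for dp in ds.get_data():
        # full batches: dp[0].shape == (4, 4, 4); with remainder=True the
        # last batch is smaller: (2, 4, 4)
        print(dp[0].shape, dp[1])

Passing use_list=True instead keeps each component as a plain list of the
original values, which skips the extra numpy copy mentioned in the docstring.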
@@ -152,9 +153,10 @@ class BatchDataByShape(BatchData):
     datapoints of different shape, and batches will be formed from those who
     have the same shape.

-    It is implemented by a dict{shape -> datapoints}.
-    Datapoints of uncommon shapes may never be enough to form a batch and
-    never get generated.
+    Note:
+        It is implemented by a dict{shape -> datapoints}.
+        Datapoints of uncommon shapes may never be enough to form a batch and
+        never get generated.
     """

     def __init__(self, ds, batch_size, idx):
         """
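A hedged sketch of the Note above (toy variable-shape data; an assumption,
not from the commit): batches form only among datapoints whose component
``idx`` shares the same shape.

    import numpy as np
    from tensorpack.dataflow import BatchDataByShape, DataFromList

    datapoints = [[np.zeros((4 if i % 2 else 8,) * 2, dtype='float32')]
                  for i in range(64)]
    ds = BatchDataByShape(DataFromList(datapoints, shuffle=False), 4, idx=0)
    ds.reset_state()
    dp = next(ds.get_data())
    print(dp[0].shape)   # (4, 8, 8) here: every member of a batch shares one shape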
@@ -184,8 +186,8 @@ class BatchDataByShape(BatchData):

 class FixedSizeData(ProxyDataFlow):
-    """ Generate data from another DataFlow, but with a fixed size.
-    The iterator of the underlying DataFlow will be kept if not exhausted.
+    """ Generate data from another DataFlow, but with a fixed total count.
+    The iterator state of the underlying DataFlow will be kept if not exhausted.
     """

     def __init__(self, ds, size):
         """
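A short sketch of the "fixed total count" wording (assumes DataFromList):
each get_data() pass yields exactly ``size`` datapoints, and the underlying
iterator resumes where it stopped.

    from tensorpack.dataflow import DataFromList, FixedSizeData

    ds = FixedSizeData(DataFromList([[i] for i in range(10)], shuffle=False), 3)
    ds.reset_state()
    print([dp[0] for dp in ds.get_data()])   # [0, 1, 2]
    print([dp[0] for dp in ds.get_data()])   # [3, 4, 5]: iterator state was kept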
@@ -220,19 +222,21 @@ class FixedSizeData(ProxyDataFlow):

 class MapData(ProxyDataFlow):
-    """ Apply a mapper/filter on the DataFlow"""
+    """
+    Apply a mapper/filter on the DataFlow.
+
+    Note:
+        1. Please make sure func doesn't modify the components
+           unless you're certain it's safe.
+        2. If you discard some datapoints, ``ds.size()`` will be incorrect.
+    """

     def __init__(self, ds, func):
         """
         Args:
             ds (DataFlow): input DataFlow
             func (datapoint -> datapoint | None): takes a datapoint and returns a new
-                datapoint. Return None to discard this data point.
-                Note that if you use the filter feature, ``ds.size()`` will be incorrect.
-
-        Note:
-            Please make sure func doesn't modify the components
-            unless you're certain it's safe.
+                datapoint. Return None to discard this datapoint.
         """
         super(MapData, self).__init__(ds)
         self.func = func
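A sketch of MapData as both mapper and filter (toy data; assumes
DataFromList): returning None discards the datapoint, which is exactly why
the Note says ``ds.size()`` becomes incorrect.

    from tensorpack.dataflow import DataFromList, MapData

    src = DataFromList([[i] for i in range(10)], shuffle=False)
    ds = MapData(src, lambda dp: [dp[0] * 10] if dp[0] % 2 == 0 else None)
    ds.reset_state()
    print([dp[0] for dp in ds.get_data()])   # [0, 20, 40, 60, 80]
    print(ds.size())                         # still 10, though only 5 are produced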
@@ -245,20 +249,23 @@ class MapData(ProxyDataFlow):

 class MapDataComponent(MapData):
-    """ Apply a mapper/filter on a datapoint component"""
+    """
+    Apply a mapper/filter on a datapoint component.
+
+    Note:
+        1. This dataflow itself doesn't modify the datapoints.
+           But please make sure func doesn't modify the components
+           unless you're certain it's safe.
+        2. If you discard some datapoints, ``ds.size()`` will be incorrect.
+    """

     def __init__(self, ds, func, index=0):
         """
         Args:
             ds (DataFlow): input DataFlow.
             func (TYPE -> TYPE|None): takes ``dp[index]``, returns a new value for ``dp[index]``.
                 return None to discard this datapoint.
-                Note that if you use the filter feature, ``ds.size()`` will be incorrect.
             index (int): index of the component.
-
-        Note:
-            This proxy itself doesn't modify the datapoints.
-            But please make sure func doesn't modify the components
-            unless you're certain it's safe.
         """
         def f(dp):
             r = func(dp[index])
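A companion sketch for MapDataComponent (toy data; assumes DataFromList):
only ``dp[index]`` is mapped, and the other components pass through unchanged.

    from tensorpack.dataflow import DataFromList, MapDataComponent

    src = DataFromList([[1, 'a'], [2, 'b'], [3, 'c']], shuffle=False)
    ds = MapDataComponent(src, lambda x: x * 100, index=0)
    ds.reset_state()
    print(list(ds.get_data()))   # [[100, 'a'], [200, 'b'], [300, 'c']]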
@@ -272,7 +279,8 @@ class MapDataComponent(MapData):

 class RepeatedData(ProxyDataFlow):
     """ Take data points from another DataFlow and produce them until
-    it's exhausted for certain amount of times.
+    it's exhausted for certain amount of times. i.e.:
+    dp1, dp2, .... dpn, dp1, dp2, ....dpn
     """

     def __init__(self, ds, nr):
@@ -306,9 +314,9 @@ class RepeatedData(ProxyDataFlow):

 class RepeatedDataPoint(ProxyDataFlow):
     """ Take data points from another DataFlow and produce them a
-    certain number of times dp1, ..., dp1, dp2, ..., dp2, ...
+    certain number of times. i.e.:
+    dp1, dp1, ..., dp1, dp2, ..., dp2, ...
     """

     def __init__(self, ds, nr):
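A sketch contrasting the two orderings now spelled out in the docstrings of
RepeatedData above and RepeatedDataPoint here (assumes DataFromList):

    from tensorpack.dataflow import DataFromList, RepeatedData, RepeatedDataPoint

    rd = RepeatedData(DataFromList([[1], [2]], shuffle=False), 2)
    rp = RepeatedDataPoint(DataFromList([[1], [2]], shuffle=False), 2)
    for ds in (rd, rp):
        ds.reset_state()
        print([dp[0] for dp in ds.get_data()])
    # RepeatedData:      [1, 2, 1, 2]
    # RepeatedDataPoint: [1, 1, 2, 2]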
@@ -408,7 +416,9 @@ class RandomMixData(RNGDataFlow):

 class ConcatData(DataFlow):
     """
-    Concatenate several dataflows. Produce datapoints from them one by one.
+    Concatenate several DataFlow.
+    Produce datapoints from each DataFlow and go to the next when one
+    DataFlow is exhausted.
     """

     def __init__(self, df_lists):
@@ -439,9 +449,9 @@ class JoinData(DataFlow):

     .. code-block:: none

-        dp1: [c1, c2]
-        dp2: [c3, c4]
-        join: [c1, c2, c3, c4]
+        df1 produces: [c1, c2]
+        df2 produces: [c3, c4]
+        joined: [c1, c2, c3, c4]
     """

     def __init__(self, df_lists):
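A sketch mirroring the code-block above (assumes DataFromList): ConcatData
chains whole DataFlow one after another, while JoinData concatenates the
components of each datapoint.

    from tensorpack.dataflow import ConcatData, DataFromList, JoinData

    cat = ConcatData([DataFromList([['c1', 'c2']], shuffle=False),
                      DataFromList([['c3', 'c4']], shuffle=False)])
    cat.reset_state()
    print(list(cat.get_data()))    # [['c1', 'c2'], ['c3', 'c4']]

    join = JoinData([DataFromList([['c1', 'c2']], shuffle=False),
                     DataFromList([['c3', 'c4']], shuffle=False)])
    join.reset_state()
    print(list(join.get_data()))   # [['c1', 'c2', 'c3', 'c4']]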
@@ -495,9 +505,9 @@ def SelectComponent(ds, idxs):

     .. code-block:: none

-        dp: [c1, c2, c3]
+        original df produces: [c1, c2, c3]
         idxs: [2,1]
-        output dp: [c3, c2]
+        this df: [c3, c2]
     """
     return MapData(ds, lambda dp: [dp[i] for i in idxs])
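The same example as the code-block above, as a runnable sketch (assumes
DataFromList):

    from tensorpack.dataflow import DataFromList, SelectComponent

    ds = SelectComponent(DataFromList([['c1', 'c2', 'c3']], shuffle=False), [2, 1])
    ds.reset_state()
    print(list(ds.get_data()))   # [['c3', 'c2']]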
@@ -561,7 +571,8 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):

 class CacheData(ProxyDataFlow):
     """
-    Cache a dataflow completely in memory.
+    Cache the first pass of a DataFlow completely in memory,
+    and produce from the cache thereafter.
     """

     def __init__(self, ds, shuffle=False):
         """
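A sketch of the clarified CacheData semantics (assumes DataFromList): the
first get_data() pass reads the source and fills the cache; later passes
replay it from memory.

    from tensorpack.dataflow import CacheData, DataFromList

    ds = CacheData(DataFromList([[i] for i in range(5)], shuffle=False))
    ds.reset_state()
    first = [dp[0] for dp in ds.get_data()]   # reads the source and caches
    again = [dp[0] for dp in ds.get_data()]   # served from the in-memory cache
    assert first == again == [0, 1, 2, 3, 4]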
tensorpack/dataflow/format.py

@@ -29,9 +29,9 @@ class HDF5Data(RNGDataFlow):
     Zip data from different paths in an HDF5 file.

     Warning:
-        The current implementation will load all data into memory.
+        The current implementation will load all data into memory. (TODO)
     """
-    # TODO lazy load
+    # TODO

     def __init__(self, filename, data_paths, shuffle=True):
         """

@@ -61,7 +61,8 @@ class HDF5Data(RNGDataFlow):

 class LMDBData(RNGDataFlow):
-    """ Read a LMDB database and produce (k,v) pairs """
+    """ Read a LMDB database and produce (k,v) raw string pairs.
+    """

     def __init__(self, lmdb_path, shuffle=True, keys=None):
         """
         Args:

@@ -79,7 +80,7 @@ class LMDBData(RNGDataFlow):
         self._lmdb_path = lmdb_path
         self._shuffle = shuffle

-        self.open_lmdb()
+        self._open_lmdb()
         self._size = self._txn.stat()['entries']
         self._set_keys(keys)
         logger.info("Found {} entries in {}".format(self._size, self._lmdb_path))

@@ -113,7 +114,7 @@ class LMDBData(RNGDataFlow):
         else:
             self.keys = keys

-    def open_lmdb(self):
+    def _open_lmdb(self):
         self._lmdb = lmdb.open(self._lmdb_path,
                                subdir=os.path.isdir(self._lmdb_path),
                                readonly=True, lock=False, readahead=True,

@@ -123,7 +124,7 @@ class LMDBData(RNGDataFlow):
     def reset_state(self):
         self._lmdb.close()
         super(LMDBData, self).reset_state()
-        self.open_lmdb()
+        self._open_lmdb()

     def size(self):
         return self._size
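A hedged usage sketch of LMDBData per the clarified docstring
('/path/to/db.lmdb' is a placeholder). Note that after this commit the open
helper is the private _open_lmdb(), so callers should rely only on the
DataFlow interface.

    from tensorpack.dataflow import LMDBData

    ds = LMDBData('/path/to/db.lmdb', shuffle=False)
    ds.reset_state()
    print(ds.size())           # number of entries found in the database
    for k, v in ds.get_data():
        pass                   # k, v are raw strings; deserialize them yourself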
tensorpack/models/common.py

@@ -131,6 +131,7 @@ def layer_register(

         if name is not None:  # use scope
             with tf.variable_scope(name) as scope:
+                # this name is only used to surpress logging, doesn't hurt to do some heuristics
                 scope_name = re.sub('tower[0-9]+/', '', scope.name)
                 do_log_shape = log_shape and scope_name not in _LAYER_LOGGED
                 if do_log_shape:
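A small sketch of the heuristic the new comment documents: stripping the
training-tower prefix so a layer replicated across towers is only logged once
(the scope names below are hypothetical).

    import re

    for scope_name in ('tower0/conv1', 'tower12/conv1', 'conv1'):
        print(re.sub('tower[0-9]+/', '', scope_name))   # 'conv1' every time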