Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
6f9f4cd9
Commit
6f9f4cd9
authored
Jul 07, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
locallyshuffle
parent
1126fa5c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
82 additions
and
34 deletions
+82
-34
tensorpack/dataflow/base.py
tensorpack/dataflow/base.py
+1
-0
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+22
-34
tensorpack/dataflow/raw.py
tensorpack/dataflow/raw.py
+59
-0
No files found.
tensorpack/dataflow/base.py
View file @
6f9f4cd9
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
from
abc
import
abstractmethod
,
ABCMeta
from
abc
import
abstractmethod
,
ABCMeta
from
..utils
import
get_rng
__all__
=
[
'DataFlow'
,
'ProxyDataFlow'
]
__all__
=
[
'DataFlow'
,
'ProxyDataFlow'
]
...
...
tensorpack/dataflow/common.py
View file @
6f9f4cd9
...
@@ -7,13 +7,13 @@ import copy
...
@@ -7,13 +7,13 @@ import copy
import
numpy
as
np
import
numpy
as
np
from
collections
import
deque
from
collections
import
deque
from
six.moves
import
range
,
map
from
six.moves
import
range
,
map
from
.base
import
DataFlow
,
ProxyDataFlow
from
.base
import
DataFlow
,
ProxyDataFlow
,
RNGDataFlow
from
..utils
import
*
from
..utils
import
*
__all__
=
[
'BatchData'
,
'FixedSizeData'
,
'
FakeData'
,
'
MapData'
,
__all__
=
[
'BatchData'
,
'FixedSizeData'
,
'MapData'
,
'RepeatedData'
,
'MapDataComponent'
,
'RandomChooseData'
,
'RepeatedData'
,
'MapDataComponent'
,
'RandomChooseData'
,
'RandomMixData'
,
'JoinData'
,
'ConcatData'
,
'SelectComponent'
,
'RandomMixData'
,
'JoinData'
,
'ConcatData'
,
'SelectComponent'
,
'
DataFromQueue'
,
'
LocallyShuffleData'
]
'LocallyShuffleData'
]
class
BatchData
(
ProxyDataFlow
):
class
BatchData
(
ProxyDataFlow
):
def
__init__
(
self
,
ds
,
batch_size
,
remainder
=
False
):
def
__init__
(
self
,
ds
,
batch_size
,
remainder
=
False
):
...
@@ -135,25 +135,6 @@ class RepeatedData(ProxyDataFlow):
...
@@ -135,25 +135,6 @@ class RepeatedData(ProxyDataFlow):
for
dp
in
self
.
ds
.
get_data
():
for
dp
in
self
.
ds
.
get_data
():
yield
dp
yield
dp
class FakeData(RNGDataFlow):
    """ Generate fake random data of given shapes"""

    def __init__(self, shapes, size):
        """
        :param shapes: a list of lists/tuples; one array is produced per
            entry for every datapoint
        :param size: size of this DataFlow (converted with ``int``)
        """
        super(FakeData, self).__init__()
        self.shapes = shapes
        self._size = int(size)

    def size(self):
        """ Return the number of datapoints produced by one pass of get_data"""
        return self._size

    def get_data(self):
        """
        Yield ``size()`` datapoints. Each datapoint is a list holding one
        float32 array per entry of ``self.shapes``.
        """
        # NOTE(review): self.rng is not set in this class; presumably it is
        # provided by RNGDataFlow -- confirm against base.py.
        remaining = self._size
        while remaining > 0:
            datapoint = []
            for shape in self.shapes:
                sample = self.rng.random_sample(shape)
                datapoint.append(sample.astype('float32'))
            yield datapoint
            remaining -= 1
class
MapData
(
ProxyDataFlow
):
class
MapData
(
ProxyDataFlow
):
""" Apply map/filter a function on the datapoint"""
""" Apply map/filter a function on the datapoint"""
def
__init__
(
self
,
ds
,
func
):
def
__init__
(
self
,
ds
,
func
):
...
@@ -323,24 +304,31 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
...
@@ -323,24 +304,31 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
def
__init__
(
self
,
ds
,
cache_size
):
def
__init__
(
self
,
ds
,
cache_size
):
ProxyDataFlow
.
__init__
(
self
,
ds
)
ProxyDataFlow
.
__init__
(
self
,
ds
)
RNGDataFlow
.
__init__
(
self
)
RNGDataFlow
.
__init__
(
self
)
self
.
cache_size
=
cache_size
self
.
q
=
deque
(
maxlen
=
cache_size
)
self
.
q
=
deque
(
maxlen
=
self
.
cache_size
)
self
.
ds_wrap
=
RepeatedData
(
ds
,
-
1
)
self
.
ds_itr
=
self
.
ds_wrap
.
get_data
()
self
.
current_cnt
=
0
def
reset_state
(
self
):
def
reset_state
(
self
):
ProxyDataFlow
.
reset_state
(
self
)
RNGDataFlow
.
reset_state
(
self
)
RNGDataFlow
.
reset_state
(
self
)
self
.
ds_wrap
=
RepeatedData
(
self
.
ds
,
-
1
)
self
.
ds_itr
=
self
.
ds_wrap
.
get_data
()
self
.
current_cnt
=
0
def
get_data
(
self
):
def
get_data
(
self
):
# TODO
for
_
in
range
(
self
.
q
.
maxlen
-
len
(
self
.
q
)):
pass
self
.
q
.
append
(
next
(
self
.
ds_itr
))
cnt
=
0
class
DataFromQueue
(
DataFlow
):
""" Provide data from a queue """
def
__init__
(
self
,
queue
):
self
.
queue
=
queue
def
get_data
(
self
):
while
True
:
while
True
:
yield
self
.
queue
.
get
()
self
.
rng
.
shuffle
(
self
.
q
)
for
_
in
range
(
self
.
q
.
maxlen
):
yield
self
.
q
.
popleft
()
cnt
+=
1
if
cnt
==
self
.
size
():
return
self
.
q
.
append
(
next
(
self
.
ds_itr
))
def
SelectComponent
(
ds
,
idxs
):
def
SelectComponent
(
ds
,
idxs
):
"""
"""
...
...
tensorpack/dataflow/raw.py
0 → 100644
View file @
6f9f4cd9
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: raw.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
numpy
as
np
from
six.moves
import
range
from
.base
import
DataFlow
,
RNGDataFlow
__all__
=
[
'FakeData'
,
'DataFromQueue'
,
'DataFromList'
]
class FakeData(RNGDataFlow):
    """ Generate fake random data of given shapes"""

    def __init__(self, shapes, size):
        """
        :param shapes: a list of lists/tuples; one array is produced per
            entry for every datapoint
        :param size: size of this DataFlow (converted with ``int``)
        """
        super(FakeData, self).__init__()
        self.shapes = shapes
        self._size = int(size)

    def size(self):
        """ Return the number of datapoints produced by one pass of get_data"""
        return self._size

    def get_data(self):
        """
        Yield ``size()`` datapoints. Each datapoint is a list holding one
        float32 array per entry of ``self.shapes``.
        """
        # NOTE(review): self.rng is not set in this class; presumably it is
        # provided by RNGDataFlow -- confirm against base.py.
        produced = 0
        total = self._size
        while produced < total:
            datapoint = []
            for shape in self.shapes:
                raw = self.rng.random_sample(shape)
                datapoint.append(raw.astype('float32'))
            yield datapoint
            produced += 1
class DataFromQueue(DataFlow):
    """ Produce data from a queue """

    def __init__(self, queue):
        """
        :param queue: an object with a ``get()`` method; each value it
            returns is yielded as one datapoint
        """
        self.queue = queue

    def get_data(self):
        """
        Yield the result of ``self.queue.get()`` forever. The generator
        never finishes on its own; the caller decides when to stop.
        """
        while 1:
            datapoint = self.queue.get()
            yield datapoint
class DataFromList(RNGDataFlow):
    """ Produce data from a list"""

    def __init__(self, lst, shuffle=True):
        """
        :param lst: the list of datapoints to produce
        :param shuffle: if True, produce the datapoints in a freshly
            shuffled order on every call to get_data; otherwise keep the
            list order
        """
        super(DataFromList, self).__init__()
        self.lst = lst
        self.shuffle = shuffle

    def size(self):
        """ Return the number of datapoints, i.e. ``len(lst)``"""
        return len(self.lst)

    def get_data(self):
        """
        Yield every element of ``self.lst`` exactly once, in list order
        when ``self.shuffle`` is false and in a random order otherwise.
        """
        if not self.shuffle:
            for k in self.lst:
                yield k
        else:
            # Shuffle an index array, not the list, so self.lst is untouched.
            # rng.shuffle works in place and returns None, so its return
            # value must not be used as the index sequence.
            # NOTE(review): self.rng is presumably provided by RNGDataFlow
            # -- confirm against base.py.
            idxs = np.arange(len(self.lst))
            self.rng.shuffle(idxs)
            for k in idxs:
                yield self.lst[k]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment