Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
5476b488
Commit
5476b488
authored
Apr 19, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
map/filter dataflow
parent
174c3fc9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
9 deletions
+15
-9
examples/ResNet/svhn_resnet.py
examples/ResNet/svhn_resnet.py
+1
-1
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+14
-8
No files found.
examples/ResNet/svhn_resnet.py
View file @
5476b488
...
@@ -171,7 +171,7 @@ def get_config():
...
@@ -171,7 +171,7 @@ def get_config():
InferenceRunner
(
dataset_test
,
InferenceRunner
(
dataset_test
,
[
ScalarStats
(
'cost'
),
ClassificationError
()
]),
[
ScalarStats
(
'cost'
),
ClassificationError
()
]),
ScheduledHyperParamSetter
(
'learning_rate'
,
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
1
,
0.1
),
(
20
,
0.01
),
(
33
,
0.001
),
(
6
0
,
0.0001
)])
[(
1
,
0.1
),
(
20
,
0.01
),
(
28
,
0.001
),
(
5
0
,
0.0001
)])
]),
]),
session_config
=
sess_config
,
session_config
=
sess_config
,
model
=
Model
(
n
=
18
),
model
=
Model
(
n
=
18
),
...
...
tensorpack/dataflow/common.py
View file @
5476b488
...
@@ -144,25 +144,29 @@ class FakeData(DataFlow):
...
@@ -144,25 +144,29 @@ class FakeData(DataFlow):
yield
[
self
.
rng
.
random_sample
(
k
)
for
k
in
self
.
shapes
]
yield
[
self
.
rng
.
random_sample
(
k
)
for
k
in
self
.
shapes
]
class
MapData
(
ProxyDataFlow
):
class
MapData
(
ProxyDataFlow
):
"""
Map
a function on the datapoint"""
"""
Apply map/filter
a function on the datapoint"""
def
__init__
(
self
,
ds
,
func
):
def
__init__
(
self
,
ds
,
func
):
"""
"""
:param ds: a :mod:`DataFlow` instance.
:param ds: a :mod:`DataFlow` instance.
:param func: a function that takes a original datapoint, returns a new datapoint
:param func: a function that takes a original datapoint, returns a new
datapoint. return None to skip this data point.
"""
"""
super
(
MapData
,
self
)
.
__init__
(
ds
)
super
(
MapData
,
self
)
.
__init__
(
ds
)
self
.
func
=
func
self
.
func
=
func
def
get_data
(
self
):
def
get_data
(
self
):
for
dp
in
self
.
ds
.
get_data
():
for
dp
in
self
.
ds
.
get_data
():
yield
self
.
func
(
dp
)
ret
=
self
.
func
(
dp
)
if
ret
is
not
None
:
yield
ret
class
MapDataComponent
(
ProxyDataFlow
):
class
MapDataComponent
(
ProxyDataFlow
):
""" Apply
a function to
the given index in the datapoint"""
""" Apply
map/filter on
the given index in the datapoint"""
def
__init__
(
self
,
ds
,
func
,
index
=
0
):
def
__init__
(
self
,
ds
,
func
,
index
=
0
):
"""
"""
:param ds: a :mod:`DataFlow` instance.
:param ds: a :mod:`DataFlow` instance.
:param func: a function that takes a datapoint dp[index], returns a new value of dp[index]
:param func: a function that takes a datapoint dp[index], returns a
new value of dp[index]. return None to skip this datapoint.
"""
"""
super
(
MapDataComponent
,
self
)
.
__init__
(
ds
)
super
(
MapDataComponent
,
self
)
.
__init__
(
ds
)
self
.
func
=
func
self
.
func
=
func
...
@@ -170,9 +174,11 @@ class MapDataComponent(ProxyDataFlow):
...
@@ -170,9 +174,11 @@ class MapDataComponent(ProxyDataFlow):
def
get_data
(
self
):
def
get_data
(
self
):
for
dp
in
self
.
ds
.
get_data
():
for
dp
in
self
.
ds
.
get_data
():
dp
=
copy
.
deepcopy
(
dp
)
# avoid modifying the original dp
repl
=
self
.
func
(
dp
[
self
.
index
])
dp
[
self
.
index
]
=
self
.
func
(
dp
[
self
.
index
])
if
repl
is
not
None
:
yield
dp
dp
=
copy
.
deepcopy
(
dp
)
# avoid modifying the original dp
dp
[
self
.
index
]
=
repl
yield
dp
class
RandomChooseData
(
DataFlow
):
class
RandomChooseData
(
DataFlow
):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment