Commit 5476b488 authored by Yuxin Wu

map/filter dataflow

parent 174c3fc9
...@@ -171,7 +171,7 @@ def get_config(): ...@@ -171,7 +171,7 @@ def get_config():
InferenceRunner(dataset_test, InferenceRunner(dataset_test,
[ScalarStats('cost'), ClassificationError() ]), [ScalarStats('cost'), ClassificationError() ]),
ScheduledHyperParamSetter('learning_rate', ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (20, 0.01), (33, 0.001), (60, 0.0001)]) [(1, 0.1), (20, 0.01), (28, 0.001), (50, 0.0001)])
]), ]),
session_config=sess_config, session_config=sess_config,
model=Model(n=18), model=Model(n=18),
......
...@@ -144,25 +144,29 @@ class FakeData(DataFlow): ...@@ -144,25 +144,29 @@ class FakeData(DataFlow):
yield [self.rng.random_sample(k) for k in self.shapes] yield [self.rng.random_sample(k) for k in self.shapes]
class MapData(ProxyDataFlow): class MapData(ProxyDataFlow):
""" Map a function on the datapoint""" """ Apply map/filter a function on the datapoint"""
def __init__(self, ds, func): def __init__(self, ds, func):
""" """
:param ds: a :mod:`DataFlow` instance. :param ds: a :mod:`DataFlow` instance.
:param func: a function that takes a original datapoint, returns a new datapoint :param func: a function that takes a original datapoint, returns a new
datapoint. return None to skip this data point.
""" """
super(MapData, self).__init__(ds) super(MapData, self).__init__(ds)
self.func = func self.func = func
def get_data(self): def get_data(self):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
yield self.func(dp) ret = self.func(dp)
if ret is not None:
yield ret
class MapDataComponent(ProxyDataFlow): class MapDataComponent(ProxyDataFlow):
""" Apply a function to the given index in the datapoint""" """ Apply map/filter on the given index in the datapoint"""
def __init__(self, ds, func, index=0): def __init__(self, ds, func, index=0):
""" """
:param ds: a :mod:`DataFlow` instance. :param ds: a :mod:`DataFlow` instance.
:param func: a function that takes a datapoint dp[index], returns a new value of dp[index] :param func: a function that takes a datapoint dp[index], returns a
new value of dp[index]. return None to skip this datapoint.
""" """
super(MapDataComponent, self).__init__(ds) super(MapDataComponent, self).__init__(ds)
self.func = func self.func = func
...@@ -170,8 +174,10 @@ class MapDataComponent(ProxyDataFlow): ...@@ -170,8 +174,10 @@ class MapDataComponent(ProxyDataFlow):
def get_data(self): def get_data(self):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
repl = self.func(dp[self.index])
if repl is not None:
dp = copy.deepcopy(dp) # avoid modifying the original dp dp = copy.deepcopy(dp) # avoid modifying the original dp
dp[self.index] = self.func(dp[self.index]) dp[self.index] = repl
yield dp yield dp
class RandomChooseData(DataFlow): class RandomChooseData(DataFlow):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment