Commit 6bdd0460 authored by Yuxin Wu's avatar Yuxin Wu

add plasma store

parent 1f07de76
...@@ -186,6 +186,7 @@ class MultiProcessPrefetchData(ProxyDataFlow): ...@@ -186,6 +186,7 @@ class MultiProcessPrefetchData(ProxyDataFlow):
PrefetchData = MultiProcessPrefetchData PrefetchData = MultiProcessPrefetchData
# TODO renamed to MultiProcessDataFlow{,ZMQ} if separated to a new project
class PrefetchDataZMQ(_MultiProcessZMQDataFlow): class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
""" """
Prefetch data from a DataFlow using multiple processes, with ZeroMQ for Prefetch data from a DataFlow using multiple processes, with ZeroMQ for
...@@ -379,3 +380,46 @@ class MultiThreadPrefetchData(DataFlow): ...@@ -379,3 +380,46 @@ class MultiThreadPrefetchData(DataFlow):
for p in self.threads: for p in self.threads:
p.stop() p.stop()
p.join() p.join()
class PlasmaPutData(ProxyDataFlow):
    """
    Store each datapoint from the wrapped DataFlow in the plasma
    shared-memory object store, and yield a one-element list holding the
    binary object id instead of the datapoint itself.
    """
    def __init__(self, ds):
        super(PlasmaPutData, self).__init__(ds)

    def reset_state(self):
        # NOTE(review): the client is created in reset_state rather than
        # __init__, presumably so each (possibly forked) process opens its
        # own connection — confirm against how ProxyDataFlow is used.
        super(PlasmaPutData, self).reset_state()
        self.client = plasma.connect("/tmp/plasma", "", 0)

    def get_data(self):
        for datapoint in self.ds.get_data():
            object_id = self.client.put(datapoint)
            yield [object_id.binary()]
class PlasmaGetData(ProxyDataFlow):
    """
    Take binary plasma object ids produced by the wrapped DataFlow
    (as one-element lists) and yield the corresponding objects fetched
    back from the plasma shared-memory object store.
    """
    def __init__(self, ds):
        super(PlasmaGetData, self).__init__(ds)

    def reset_state(self):
        super(PlasmaGetData, self).reset_state()
        self.client = plasma.connect("/tmp/plasma", "", 0)

    def get_data(self):
        for encoded in self.ds.get_data():
            # encoded[0] is the binary id emitted by PlasmaPutData.
            object_id = plasma.ObjectID(encoded[0])
            yield self.client.get(object_id)
# pyarrow is an optional dependency. If it is unavailable, replace the
# plasma-backed classes defined above with dummies (create_dummy_class is
# defined elsewhere in this project; presumably its dummies raise an
# informative error on use — confirm).
try:
    import pyarrow.plasma as plasma
except ImportError:
    PlasmaPutData = create_dummy_class('PlasmaPutData', 'pyarrow')  # noqa
    PlasmaGetData = create_dummy_class('PlasmaGetData', 'pyarrow')  # noqa
...@@ -181,7 +181,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5, ...@@ -181,7 +181,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
@layer_register() @layer_register()
def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5, def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
use_scale=True, use_bias=True, data_format='NHWC'): use_scale=True, use_bias=True, gamma_init=None, data_format='NHWC'):
""" """
Batch Renormalization layer, as described in the paper: Batch Renormalization layer, as described in the paper:
`Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models `Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
...@@ -230,6 +230,7 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5, ...@@ -230,6 +230,7 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
'rmax': rmax, 'rmax': rmax,
'dmax': dmax}, 'dmax': dmax},
renorm_momentum=0.99, renorm_momentum=0.99,
gamma_initializer=gamma_init,
fused=False) fused=False)
xn = layer.apply(x, training=ctx.is_training, scope=tf.get_variable_scope()) xn = layer.apply(x, training=ctx.is_training, scope=tf.get_variable_scope())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment