Commit 4f352b29 authored by Yuxin Wu's avatar Yuxin Wu

pipedir in prefetchzmq

parent ef117f9d
...@@ -75,7 +75,7 @@ class AugmentorList(ImageAugmentor): ...@@ -75,7 +75,7 @@ class AugmentorList(ImageAugmentor):
super(AugmentorList, self).__init__() super(AugmentorList, self).__init__()
def _get_augment_params(self, img): def _get_augment_params(self, img):
# the next augmentor requires the previos one to finish # the next augmentor requires the previous one to finish
raise RuntimeError("Cannot simply get parameters of a AugmentorList!") raise RuntimeError("Cannot simply get parameters of a AugmentorList!")
def _augment_return_params(self, img): def _augment_return_params(self, img):
......
...@@ -95,11 +95,12 @@ class PrefetchProcessZMQ(multiprocessing.Process): ...@@ -95,11 +95,12 @@ class PrefetchProcessZMQ(multiprocessing.Process):
class PrefetchDataZMQ(ProxyDataFlow): class PrefetchDataZMQ(ProxyDataFlow):
""" Work the same as `PrefetchData`, but faster. """ """ Work the same as `PrefetchData`, but faster. """
def __init__(self, ds, nr_proc=1): def __init__(self, ds, nr_proc=1, pipedir='.'):
""" """
:param ds: a `DataFlow` instance. :param ds: a `DataFlow` instance.
:param nr_proc: number of processes to use. When larger than 1, order :param nr_proc: number of processes to use. When larger than 1, order
of datapoints will be random. of datapoints will be random.
:param pipedir: a local directory where the pipes would be. Useful if you're running on non-local FS such as NFS.
""" """
super(PrefetchDataZMQ, self).__init__(ds) super(PrefetchDataZMQ, self).__init__(ds)
try: try:
...@@ -110,7 +111,8 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -110,7 +111,8 @@ class PrefetchDataZMQ(ProxyDataFlow):
self.context = zmq.Context() self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL) self.socket = self.context.socket(zmq.PULL)
self.pipename = "ipc://dataflow-pipe-" + str(uuid.uuid1())[:6] assert os.path.isdir(pipedir)
self.pipename = "ipc://{}/dataflow-pipe-".format(pipedir.rstrip('/')) + str(uuid.uuid1())[:6]
self.socket.set_hwm(5) # a little bit faster than default, don't know why self.socket.set_hwm(5) # a little bit faster than default, don't know why
self.socket.bind(self.pipename) self.socket.bind(self.pipename)
...@@ -130,9 +132,16 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -130,9 +132,16 @@ class PrefetchDataZMQ(ProxyDataFlow):
yield dp yield dp
def __del__(self): def __del__(self):
logger.info("Prefetch process exiting...") # on exit, logger may not be functional anymore
try:
logger.info("Prefetch process exiting...")
except:
pass
if not self.context.closed: if not self.context.closed:
self.context.destroy(0) self.context.destroy(0)
for x in self.procs: for x in self.procs:
x.terminate() x.terminate()
logger.info("Prefetch process exited.") try:
logger.info("Prefetch process exited.")
except:
pass
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment