Commit 48a19f6d authored by Yuxin Wu's avatar Yuxin Wu

fix exiting behavior

parent 8bdd9c85
...@@ -97,7 +97,7 @@ def get_data(train_or_test): ...@@ -97,7 +97,7 @@ def get_data(train_or_test):
ds = AugmentImageComponent(ds, augmentors) ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, 128, remainder=not isTrain) ds = BatchData(ds, 128, remainder=not isTrain)
if isTrain: if isTrain:
ds = PrefetchDataZMQ(ds, 5) ds = PrefetchData(ds, 10, 5)
return ds return ds
def get_config(): def get_config():
...@@ -127,7 +127,7 @@ def get_config(): ...@@ -127,7 +127,7 @@ def get_config():
session_config=sess_config, session_config=sess_config,
model=Model(), model=Model(),
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
max_epoch=300, max_epoch=250,
) )
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -67,14 +67,6 @@ class PrefetchData(ProxyDataFlow): ...@@ -67,14 +67,6 @@ class PrefetchData(ProxyDataFlow):
dp = self.queue.get() dp = self.queue.get()
yield dp yield dp
def __del__(self):
logger.info("Prefetch process exiting...")
self.queue.close()
for x in self.procs:
x.terminate()
logger.info("Prefetch process exited.")
class PrefetchProcessZMQ(multiprocessing.Process): class PrefetchProcessZMQ(multiprocessing.Process):
def __init__(self, ds, conn_name): def __init__(self, ds, conn_name):
""" """
...@@ -118,6 +110,10 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -118,6 +110,10 @@ class PrefetchDataZMQ(ProxyDataFlow):
for x in self.procs: for x in self.procs:
x.start() x.start()
# __del__ not guranteed to get called at exit
import atexit
atexit.register(lambda x: x.__del__(), self)
def get_data(self): def get_data(self):
for _ in range(self._size): for _ in range(self._size):
dp = loads(self.socket.recv(copy=False)) dp = loads(self.socket.recv(copy=False))
...@@ -125,7 +121,8 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -125,7 +121,8 @@ class PrefetchDataZMQ(ProxyDataFlow):
def __del__(self): def __del__(self):
logger.info("Prefetch process exiting...") logger.info("Prefetch process exiting...")
self.context.destroy(0) if not self.context.closed:
self.context.destroy(0)
for x in self.procs: for x in self.procs:
x.terminate() x.terminate()
logger.info("Prefetch process exited.") logger.info("Prefetch process exited.")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment