Commit 2b2e12ed authored by Yuxin Wu's avatar Yuxin Wu

minor update

parent 552ec9e9
...@@ -130,8 +130,8 @@ def get_data(train_or_test, isMixup, alpha): ...@@ -130,8 +130,8 @@ def get_data(train_or_test, isMixup, alpha):
batch = BATCH_SIZE batch = BATCH_SIZE
ds = BatchData(ds, batch, remainder=not isTrain) ds = BatchData(ds, batch, remainder=not isTrain)
def f(ds): def f(dp):
images, labels = ds images, labels = dp
one_hot_labels = np.eye(CLASS_NUM)[labels] # one hot coding one_hot_labels = np.eye(CLASS_NUM)[labels] # one hot coding
if not isTrain or not isMixup: if not isTrain or not isMixup:
return [images, one_hot_labels] return [images, one_hot_labels]
......
...@@ -97,6 +97,7 @@ class _MultiProcessZMQDataFlow(DataFlow): ...@@ -97,6 +97,7 @@ class _MultiProcessZMQDataFlow(DataFlow):
if not self._reset_done: if not self._reset_done:
return return
if not self.context.closed: if not self.context.closed:
self.socket.close(0)
self.context.destroy(0) self.context.destroy(0)
for x in self._procs: for x in self._procs:
x.terminate() x.terminate()
...@@ -239,6 +240,9 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow): ...@@ -239,6 +240,9 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
# sigint could still propagate here, e.g. when nested # sigint could still propagate here, e.g. when nested
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
finally:
socket.close(0)
context.destroy(0)
def __init__(self, ds, nr_proc=1, hwm=50): def __init__(self, ds, nr_proc=1, hwm=50):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment