Commit 0ba0336c authored by Yuxin Wu

prefetch with multiprocessing

parent f7af025e
...@@ -20,7 +20,7 @@ from tensorpack.dataflow import imgaug ...@@ -20,7 +20,7 @@ from tensorpack.dataflow import imgaug
""" """
This config follows the same preprocessing/model/hyperparameters as in This config follows the same preprocessing/model/hyperparameters as in
tensorflow cifar10 examples. (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/models/image/cifar10/) tensorflow cifar10 examples. (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/models/image/cifar10/)
But it's faster. 86% accuracy. faster.
""" """
BATCH_SIZE = 128 BATCH_SIZE = 128
......
...@@ -63,18 +63,26 @@ class FixedSizeData(DataFlow): ...@@ -63,18 +63,26 @@ class FixedSizeData(DataFlow):
def __init__(self, ds, size): def __init__(self, ds, size):
self.ds = ds self.ds = ds
self._size = size self._size = size
self.itr = None
def size(self): def size(self):
return self._size return self._size
def get_data(self): def get_data(self):
if self.itr is None:
self.itr = self.ds.get_data()
cnt = 0 cnt = 0
while True: while True:
for dp in self.ds.get_data(): try:
cnt += 1 dp = self.itr.next()
yield dp except StopIteration:
if cnt == self._size: self.itr = self.ds.get_data()
return dp = self.itr.next()
cnt += 1
yield dp
if cnt == self._size:
return
class RepeatedData(DataFlow): class RepeatedData(DataFlow):
""" repeat another dataflow for certain times""" """ repeat another dataflow for certain times"""
...@@ -93,6 +101,9 @@ class RepeatedData(DataFlow): ...@@ -93,6 +101,9 @@ class RepeatedData(DataFlow):
class FakeData(DataFlow): class FakeData(DataFlow):
""" Build fake random data of given shapes""" """ Build fake random data of given shapes"""
def __init__(self, shapes, size): def __init__(self, shapes, size):
"""
shapes: list of list/tuple
"""
self.shapes = shapes self.shapes = shapes
self._size = size self._size = size
...@@ -126,6 +137,7 @@ def AugmentImageComponent(ds, augmentors, index=0): ...@@ -126,6 +137,7 @@ def AugmentImageComponent(ds, augmentors, index=0):
augmentors: a list of ImageAugmentor instance augmentors: a list of ImageAugmentor instance
index: the index of image in each data point. default to be 0 index: the index of image in each data point. default to be 0
""" """
# TODO reset rng at the beginning of each get_data
aug = AugmentorList(augmentors) aug = AugmentorList(augmentors)
return MapData( return MapData(
ds, ds,
......
...@@ -13,40 +13,48 @@ class Sentinel: ...@@ -13,40 +13,48 @@ class Sentinel:
pass pass
class PrefetchProcess(multiprocessing.Process): class PrefetchProcess(multiprocessing.Process):
def __init__(self, ds, queue_size): def __init__(self, ds, queue):
"""
ds: ds to take data from
queue: output queue to put results in
"""
super(PrefetchProcess, self).__init__() super(PrefetchProcess, self).__init__()
self.ds = ds self.ds = ds
self.queue = multiprocessing.Queue(queue_size) self.queue = queue
def run(self): def run(self):
for dp in self.ds.get_data(): try:
self.queue.put(dp) for dp in self.ds.get_data():
self.queue.put(Sentinel()) self.queue.put(dp)
finally:
self.queue.put(Sentinel())
def get_data(self):
while True:
ret = self.queue.get()
if isinstance(ret, Sentinel):
return
yield ret
class PrefetchData(DataFlow): class PrefetchData(DataFlow):
def __init__(self, ds, nr_prefetch): def __init__(self, ds, nr_prefetch, nr_proc=1):
"""
use multiprocess, will duplicate ds by nr_proc times
"""
self.ds = ds self.ds = ds
self.nr_prefetch = int(nr_prefetch) self.nr_proc = nr_proc
assert self.nr_prefetch > 0 self.nr_prefetch = nr_prefetch
def size(self):
return self.ds.size()
def get_data(self): def get_data(self):
worker = PrefetchProcess(self.ds, self.nr_prefetch) queue = multiprocessing.Queue(self.nr_prefetch)
# TODO register terminate function procs = [PrefetchProcess(self.ds, queue) for _ in range(self.nr_proc)]
worker.start() [x.start() for x in procs]
end_cnt = 0
try: try:
for dp in worker.get_data(): while True:
dp = queue.get()
if isinstance(dp, Sentinel):
end_cnt += 1
if end_cnt == self.nr_proc:
break
continue
yield dp yield dp
finally: finally:
worker.join() queue.close()
worker.terminate() [x.terminate() for x in procs]
...@@ -11,10 +11,11 @@ __all__ = ['BatchNorm'] ...@@ -11,10 +11,11 @@ __all__ = ['BatchNorm']
# http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow # http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
# Only work for 4D tensor right now: #804
@layer_register() @layer_register()
def BatchNorm(x, is_training): def BatchNorm(x, is_training):
""" """
x: has to be BHWC for now x: BHWC tensor
is_training: bool is_training: bool
""" """
is_training = bool(is_training) is_training = bool(is_training)
......
...@@ -11,7 +11,7 @@ __all__ = ['Conv2D'] ...@@ -11,7 +11,7 @@ __all__ = ['Conv2D']
@layer_register(summary_activation=True) @layer_register(summary_activation=True)
def Conv2D(x, out_channel, kernel_shape, def Conv2D(x, out_channel, kernel_shape,
padding='VALID', stride=1, padding='SAME', stride=1,
W_init=None, b_init=None, W_init=None, b_init=None,
nl=tf.nn.relu, split=1): nl=tf.nn.relu, split=1):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment