Commit 0ba0336c authored by Yuxin Wu

prefetch with multiprocessing

parent f7af025e
......@@ -20,7 +20,7 @@ from tensorpack.dataflow import imgaug
"""
This config follows the same preprocessing/model/hyperparameters as in the
tensorflow cifar10 examples. (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/models/image/cifar10/)
But it's faster.
86% accuracy. Faster.
"""
BATCH_SIZE = 128
......
......@@ -63,18 +63,26 @@ class FixedSizeData(DataFlow):
    def __init__(self, ds, size):
        self.ds = ds
        self._size = size
        self.itr = None

    def size(self):
        return self._size

    def get_data(self):
        if self.itr is None:
            self.itr = self.ds.get_data()
        cnt = 0
        while True:
            for dp in self.ds.get_data():
                cnt += 1
                yield dp
                if cnt == self._size:
                    return
            try:
                dp = self.itr.next()
            except StopIteration:
                self.itr = self.ds.get_data()
                dp = self.itr.next()
            cnt += 1
            yield dp
            if cnt == self._size:
                return
class RepeatedData(DataFlow):
    """ Repeat another dataflow a given number of times """
......@@ -93,6 +101,9 @@ class RepeatedData(DataFlow):
class FakeData(DataFlow):
    """ Build fake random data of the given shapes """
    def __init__(self, shapes, size):
        """
        shapes: a list of lists/tuples
        """
        self.shapes = shapes
        self._size = size
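The FakeData hunk is truncated above, so the following is only a guess at the behaviour, not the commit's code: a minimal standalone generator that yields `size` datapoints, each a list of random arrays matching the requested shapes. fake_data and its arguments are invented names.

import numpy as np

def fake_data(shapes, size, rng=np.random):
    for _ in range(size):
        yield [rng.rand(*shape).astype('float32') for shape in shapes]

# e.g. fake (image, label)-shaped datapoints
for dp in fake_data([(32, 32, 3), (1,)], size=2):
    print([x.shape for x in dp])   # [(32, 32, 3), (1,)] each time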
......@@ -126,6 +137,7 @@ def AugmentImageComponent(ds, augmentors, index=0):
    augmentors: a list of ImageAugmentor instances
    index: the index of the image component in each datapoint. Defaults to 0.
    """
    # TODO reset rng at the beginning of each get_data
    aug = AugmentorList(augmentors)
    return MapData(
        ds,
......
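Not part of the commit: a plain-Python sketch of the idea behind AugmentImageComponent, i.e. applying a function to one component (the image at `index`) of every datapoint while passing the rest through. map_component and flip are invented stand-ins for MapData and an ImageAugmentor.

import numpy as np

def map_component(datapoints, func, index=0):
    for dp in datapoints:
        dp = list(dp)               # shallow copy so the source datapoint is untouched
        dp[index] = func(dp[index])
        yield dp

flip = lambda img: img[:, ::-1]                 # toy "augmentor": horizontal flip
points = [[np.arange(4).reshape(2, 2), 'cat']]  # one datapoint: (2x2 image, label)
print(next(map_component(points, flip)))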
......@@ -13,40 +13,48 @@ class Sentinel:
    pass

class PrefetchProcess(multiprocessing.Process):
    def __init__(self, ds, queue_size):
    def __init__(self, ds, queue):
        """
        ds: the DataFlow to take data from
        queue: the output queue to put results in
        """
        super(PrefetchProcess, self).__init__()
        self.ds = ds
        self.queue = multiprocessing.Queue(queue_size)
        self.queue = queue

    def run(self):
        for dp in self.ds.get_data():
            self.queue.put(dp)
        self.queue.put(Sentinel())
        try:
            for dp in self.ds.get_data():
                self.queue.put(dp)
        finally:
            self.queue.put(Sentinel())

    def get_data(self):
        while True:
            ret = self.queue.get()
            if isinstance(ret, Sentinel):
                return
            yield ret
class PrefetchData(DataFlow):
    def __init__(self, ds, nr_prefetch):
    def __init__(self, ds, nr_prefetch, nr_proc=1):
        """
        Use multiprocessing; ds will be duplicated nr_proc times.
        """
        self.ds = ds
        self.nr_prefetch = int(nr_prefetch)
        assert self.nr_prefetch > 0
    def size(self):
        return self.ds.size()
        self.nr_proc = nr_proc
        self.nr_prefetch = nr_prefetch

    def get_data(self):
        worker = PrefetchProcess(self.ds, self.nr_prefetch)
        # TODO register terminate function
        worker.start()
        queue = multiprocessing.Queue(self.nr_prefetch)
        procs = [PrefetchProcess(self.ds, queue) for _ in range(self.nr_proc)]
        [x.start() for x in procs]
        end_cnt = 0
        try:
            for dp in worker.get_data():
            while True:
                dp = queue.get()
                if isinstance(dp, Sentinel):
                    end_cnt += 1
                    if end_cnt == self.nr_proc:
                        break
                    continue
                yield dp
        finally:
            worker.join()
            worker.terminate()
            queue.close()
            [x.terminate() for x in procs]
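Not part of the commit, but a self-contained sketch of the producer/sentinel pattern the new PrefetchData relies on: nr_proc worker processes each run a copy of the source and push datapoints into one shared bounded queue, every worker pushes one end marker when it finishes, and the consumer stops once it has collected one marker per worker. producer, prefetch and END are invented names; the commit uses a Sentinel class instead of None.

import multiprocessing

END = None   # stands in for the commit's Sentinel class

def producer(queue, items):
    try:
        for x in items:
            queue.put(x)
    finally:
        queue.put(END)                     # always signal completion

def prefetch(make_items, nr_proc=2, queue_size=8):
    queue = multiprocessing.Queue(queue_size)
    procs = [multiprocessing.Process(target=producer, args=(queue, make_items()))
             for _ in range(nr_proc)]
    for p in procs:
        p.start()
    end_cnt = 0
    try:
        while True:
            dp = queue.get()
            if dp is END:
                end_cnt += 1
                if end_cnt == nr_proc:     # every worker has finished
                    break
                continue
            yield dp
    finally:
        queue.close()
        for p in procs:
            p.terminate()

if __name__ == '__main__':
    # each worker duplicates the source, just as PrefetchData duplicates ds
    print(sorted(prefetch(lambda: range(3), nr_proc=2)))   # [0, 0, 1, 1, 2, 2]

Counting one marker per worker is what lets the consumer distinguish "this worker is done" from "everything is done"; breaking on the first marker would drop whatever the other workers still have queued.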
......@@ -11,10 +11,11 @@ __all__ = ['BatchNorm']
# http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
# Only works for 4D tensors right now: #804
@layer_register()
def BatchNorm(x, is_training):
    """
    x: has to be BHWC for now
    x: BHWC tensor
    is_training: bool
    """
    is_training = bool(is_training)
......
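Not the commit's TensorFlow code (the hunk is truncated), just a numpy sketch of what batch normalization computes on a BHWC tensor at training time: per-channel mean and variance taken over the batch and both spatial axes, followed by the learned scale and shift. batch_norm_bhwc, gamma and beta are invented names.

import numpy as np

def batch_norm_bhwc(x, gamma, beta, eps=1e-5):
    mean = x.mean(axis=(0, 1, 2), keepdims=True)   # one statistic per channel C
    var = x.var(axis=(0, 1, 2), keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

x = np.random.randn(128, 32, 32, 3).astype('float32')   # BHWC, as the docstring requires
y = batch_norm_bhwc(x, gamma=np.ones(3), beta=np.zeros(3))
print(y.mean(axis=(0, 1, 2)), y.std(axis=(0, 1, 2)))     # ~0 and ~1 per channel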
......@@ -11,7 +11,7 @@ __all__ = ['Conv2D']
@layer_register(summary_activation=True)
def Conv2D(x, out_channel, kernel_shape,
           padding='VALID', stride=1,
           padding='SAME', stride=1,
           W_init=None, b_init=None,
           nl=tf.nn.relu, split=1):
"""
......
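Not part of the commit: a small sketch of what the default-padding change means for output sizes, using TensorFlow's documented conventions. 'SAME' keeps the spatial size at stride 1, while 'VALID' shrinks it by kernel_shape - 1. conv_output_size is an invented helper.

import math

def conv_output_size(in_size, kernel, stride, padding):
    if padding == 'SAME':
        return math.ceil(in_size / stride)
    elif padding == 'VALID':
        return math.ceil((in_size - kernel + 1) / stride)
    raise ValueError(padding)

print(conv_output_size(32, 5, 1, 'SAME'))    # 32 -- output matches the input
print(conv_output_size(32, 5, 1, 'VALID'))   # 28 -- shrinks by kernel_shape - 1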