Commit b506eb0a authored by Yuxin Wu's avatar Yuxin Wu

fix prefetch

parent 477c4446
...@@ -9,24 +9,24 @@ import multiprocessing ...@@ -9,24 +9,24 @@ import multiprocessing
__all__ = ['PrefetchData'] __all__ = ['PrefetchData']
class Sentinel:
pass
class PrefetchProcess(multiprocessing.Process): class PrefetchProcess(multiprocessing.Process):
def __init__(self, ds, queue_size): def __init__(self, ds, queue_size):
super(PrefetchProcess, self).__init__() super(PrefetchProcess, self).__init__()
self.ds = ds self.ds = ds
self.queue = multiprocessing.Queue(queue_size) self.queue = multiprocessing.Queue(queue_size)
class Sentinel:
pass
self.sentinel = Sentinel()
def run(self): def run(self):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
self.queue.put(dp) self.queue.put(dp)
self.queue.put(self.sentinel) self.queue.put(Sentinel())
def get_data(self): def get_data(self):
while True: while True:
ret = self.queue.get() ret = self.queue.get()
if ret is self.sentinel: if isinstance(ret, Sentinel):
return return
yield ret yield ret
...@@ -43,8 +43,10 @@ class PrefetchData(DataFlow): ...@@ -43,8 +43,10 @@ class PrefetchData(DataFlow):
worker = PrefetchProcess(self.ds, self.nr_prefetch) worker = PrefetchProcess(self.ds, self.nr_prefetch)
# TODO register terminate function # TODO register terminate function
worker.start() worker.start()
for dp in worker.get_data(): try:
yield dp for dp in worker.get_data():
worker.join() yield dp
worker.terminate() finally:
worker.join()
worker.terminate()
...@@ -71,10 +71,10 @@ def ImageSample(template, mapping): ...@@ -71,10 +71,10 @@ def ImageSample(template, mapping):
tf.reduce_max(diff), diff], tf.reduce_max(diff), diff],
summarize=50) summarize=50)
return sample(template, lcoor) * neg_diffx * neg_diffy + \ return tf.add_n([sample(template, lcoor) * neg_diffx * neg_diffy,
sample(template, ucoor) * diffx * diffy + \ sample(template, ucoor) * diffx * diffy,
sample(template, lyux) * neg_diffy * diffx + \ sample(template, lyux) * neg_diffy * diffx,
sample(template, uylx) * diffy * neg_diffx sample(template, uylx) * diffy * neg_diffx], name='sampled')
from _test import TestModel from _test import TestModel
class TestSample(TestModel): class TestSample(TestModel):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment