Commit b506eb0a authored by Yuxin Wu's avatar Yuxin Wu

fix prefetch

parent 477c4446
......@@ -9,24 +9,24 @@ import multiprocessing
__all__ = ['PrefetchData']
class Sentinel:
pass
class PrefetchProcess(multiprocessing.Process):
def __init__(self, ds, queue_size):
super(PrefetchProcess, self).__init__()
self.ds = ds
self.queue = multiprocessing.Queue(queue_size)
class Sentinel:
pass
self.sentinel = Sentinel()
def run(self):
for dp in self.ds.get_data():
self.queue.put(dp)
self.queue.put(self.sentinel)
self.queue.put(Sentinel())
def get_data(self):
while True:
ret = self.queue.get()
if ret is self.sentinel:
if isinstance(ret, Sentinel):
return
yield ret
......@@ -43,8 +43,10 @@ class PrefetchData(DataFlow):
worker = PrefetchProcess(self.ds, self.nr_prefetch)
# TODO register terminate function
worker.start()
for dp in worker.get_data():
yield dp
worker.join()
worker.terminate()
try:
for dp in worker.get_data():
yield dp
finally:
worker.join()
worker.terminate()
......@@ -71,10 +71,10 @@ def ImageSample(template, mapping):
tf.reduce_max(diff), diff],
summarize=50)
return sample(template, lcoor) * neg_diffx * neg_diffy + \
sample(template, ucoor) * diffx * diffy + \
sample(template, lyux) * neg_diffy * diffx + \
sample(template, uylx) * diffy * neg_diffx
return tf.add_n([sample(template, lcoor) * neg_diffx * neg_diffy,
sample(template, ucoor) * diffx * diffy,
sample(template, lyux) * neg_diffy * diffx,
sample(template, uylx) * diffy * neg_diffx], name='sampled')
from _test import TestModel
class TestSample(TestModel):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment