Commit bda9e14e authored by Yuxin Wu's avatar Yuxin Wu

Reset QueueInput (fix #487)

parent f2ca6b1a
......@@ -157,10 +157,11 @@ class Affine(TransformAugmentorBase):
"""
Random affine transform of the image w.r.t. the image center.
Transformations involve:
- Translation ("move" image on the x-/y-axis)
- Rotation
- Scaling ("zoom" in/out)
- Shear (move one side of the image, turning a square into a trapezoid)
- Translation ("move" image on the x-/y-axis)
- Rotation
- Scaling ("zoom" in/out)
- Shear (move one side of the image, turning a square into a trapezoid)
"""
def __init__(self, scale=None, translate_frac=None, rotate_max_deg=0.0, shear=0.0,
......
......@@ -11,6 +11,7 @@ except ImportError:
from itertools import chain
from six.moves import range, zip
import threading
from .input_source_base import InputSource
from ..dataflow import DataFlow, RepeatedData, DataFlowTerminated
......@@ -116,7 +117,7 @@ class EnqueueThread(ShareSessionThread):
self.name = 'EnqueueThread ' + queue.name
self.daemon = True
self.dataflow = ds
self.dataflow = RepeatedData(ds, -1)
self.queue = queue
self.placehdrs = placehdrs
......@@ -124,15 +125,20 @@ class EnqueueThread(ShareSessionThread):
self.op = self.queue.enqueue(self.placehdrs)
self.close_op = self.queue.close(cancel_pending_enqueues=True)
self._lock = threading.Lock()
def run(self):
with self.default_sess():
try:
self.dataflow.reset_state()
self.reset_dataflow()
while True:
for dp in self.dataflow.get_data():
feed = dict(zip(self.placehdrs, dp))
# print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self.op.run(feed_dict=feed)
# pausable loop
self._lock.acquire()
self._lock.release()
dp = next(self._itr)
feed = dict(zip(self.placehdrs, dp))
self.op.run(feed_dict=feed)
except (tf.errors.CancelledError, tf.errors.OutOfRangeError, DataFlowTerminated):
pass
except Exception as e:
......@@ -147,10 +153,22 @@ class EnqueueThread(ShareSessionThread):
pass
logger.info("{} Exited.".format(self.name))
def reset_dataflow(self):
    """
    Reset the wrapped dataflow and start a fresh iterator over it.

    The new iterator is stored in ``self._itr``, replacing any previous one.
    """
    flow = self.dataflow
    flow.reset_state()
    # Create the iterator only after the reset, so it reflects the fresh state.
    self._itr = flow.get_data()
def pause(self):
    """
    Acquire the internal lock (``self._lock``).

    Blocks until the lock is available; the lock stays held until
    :meth:`resume` releases it.
    """
    lock = self._lock
    lock.acquire()
def resume(self):
    """
    Release the internal lock (``self._lock``) taken by :meth:`pause`.
    """
    lock = self._lock
    lock.release()
class QueueInput(FeedfreeInput):
""" Enqueue datapoints from a DataFlow to a TF queue.
And the model receives dequeued tensors.
Calling :meth:`reset_state()` will clear the queue and reset the dataflow.
"""
def __init__(self, ds, queue=None):
......@@ -180,6 +198,25 @@ class QueueInput(FeedfreeInput):
logger.info("Setting up the queue '{}' for CPU prefetching ...".format(self.queue.name))
self.thread = EnqueueThread(self.queue, self.ds, self._input_placehdrs)
self._dequeue_op = self.queue.dequeue(name='dequeue_for_reset')
def _reset_state(self):
    """
    Clear the queue and reset the dataflow of the enqueue thread.

    Steps:
      1. Pause the enqueue thread.
      2. Dequeue repeatedly until a dequeue hits the 2-second timeout,
         which is taken to mean the queue is empty.
      3. Reset the thread's dataflow and resume the thread.

    Uses ``tf.get_default_session()``, so it must be called under a
    default session.

    The enqueue thread is always resumed, even when draining or resetting
    raises (any such exception still propagates to the caller). Otherwise
    the pause lock would stay held and the enqueue thread would stay
    blocked.
    """
    self.thread.pause()     # pause enqueue
    try:
        opt = tf.RunOptions()
        opt.timeout_in_ms = 2000   # 2s
        sess = tf.get_default_session()

        # dequeue until empty
        try:
            while True:
                sess.run(self._dequeue_op, options=opt)
        except tf.errors.DeadlineExceededError:
            # No element arrived within the timeout: treat the queue as drained.
            pass

        # reset dataflow before the thread is allowed to continue
        self.thread.reset_dataflow()
    finally:
        # Always release the pause lock, including on unexpected errors.
        self.thread.resume()
def _create_ema_callback(self):
"""
Create a hook-only callback which maintains an EMA of the queue size.
......
......@@ -128,6 +128,7 @@ class InputSource(object):
def reset_state(self):
"""
Initialize/reinitialize this InputSource.
Must be called under a default session.
For training, it will get called by the trainer in `before_train` callbacks.
For inference, the :class:`InferenceRunner` will call this method each time it is triggered.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment