Commit bda9e14e authored by Yuxin Wu

Reset QueueInput (fix #487)

parent f2ca6b1a
...@@ -157,10 +157,11 @@ class Affine(TransformAugmentorBase): ...@@ -157,10 +157,11 @@ class Affine(TransformAugmentorBase):
""" """
Random affine transform of the image w.r.t to the image center. Random affine transform of the image w.r.t to the image center.
Transformations involve: Transformations involve:
- Translation ("move" image on the x-/y-axis)
- Rotation - Translation ("move" image on the x-/y-axis)
- Scaling ("zoom" in/out) - Rotation
- Shear (move one side of the image, turning a square into a trapezoid) - Scaling ("zoom" in/out)
- Shear (move one side of the image, turning a square into a trapezoid)
""" """
def __init__(self, scale=None, translate_frac=None, rotate_max_deg=0.0, shear=0.0, def __init__(self, scale=None, translate_frac=None, rotate_max_deg=0.0, shear=0.0,
......
...@@ -11,6 +11,7 @@ except ImportError: ...@@ -11,6 +11,7 @@ except ImportError:
from itertools import chain from itertools import chain
from six.moves import range, zip from six.moves import range, zip
import threading
from .input_source_base import InputSource from .input_source_base import InputSource
from ..dataflow import DataFlow, RepeatedData, DataFlowTerminated from ..dataflow import DataFlow, RepeatedData, DataFlowTerminated
...@@ -116,7 +117,7 @@ class EnqueueThread(ShareSessionThread): ...@@ -116,7 +117,7 @@ class EnqueueThread(ShareSessionThread):
self.name = 'EnqueueThread ' + queue.name self.name = 'EnqueueThread ' + queue.name
self.daemon = True self.daemon = True
self.dataflow = ds self.dataflow = RepeatedData(ds, -1)
self.queue = queue self.queue = queue
self.placehdrs = placehdrs self.placehdrs = placehdrs
...@@ -124,15 +125,20 @@ class EnqueueThread(ShareSessionThread): ...@@ -124,15 +125,20 @@ class EnqueueThread(ShareSessionThread):
self.op = self.queue.enqueue(self.placehdrs) self.op = self.queue.enqueue(self.placehdrs)
self.close_op = self.queue.close(cancel_pending_enqueues=True) self.close_op = self.queue.close(cancel_pending_enqueues=True)
self._lock = threading.Lock()
def run(self): def run(self):
with self.default_sess(): with self.default_sess():
try: try:
self.dataflow.reset_state() self.reset_dataflow()
while True: while True:
for dp in self.dataflow.get_data(): # pausable loop
feed = dict(zip(self.placehdrs, dp)) self._lock.acquire()
# print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1] self._lock.release()
self.op.run(feed_dict=feed)
dp = next(self._itr)
feed = dict(zip(self.placehdrs, dp))
self.op.run(feed_dict=feed)
except (tf.errors.CancelledError, tf.errors.OutOfRangeError, DataFlowTerminated): except (tf.errors.CancelledError, tf.errors.OutOfRangeError, DataFlowTerminated):
pass pass
except Exception as e: except Exception as e:
...@@ -147,10 +153,22 @@ class EnqueueThread(ShareSessionThread): ...@@ -147,10 +153,22 @@ class EnqueueThread(ShareSessionThread):
pass pass
logger.info("{} Exited.".format(self.name)) logger.info("{} Exited.".format(self.name))
def reset_dataflow(self):
self.dataflow.reset_state()
self._itr = self.dataflow.get_data()
def pause(self):
self._lock.acquire()
def resume(self):
self._lock.release()
class QueueInput(FeedfreeInput): class QueueInput(FeedfreeInput):
""" Enqueue datapoints from a DataFlow to a TF queue. """ Enqueue datapoints from a DataFlow to a TF queue.
And the model receives dequeued tensors. And the model receives dequeued tensors.
Calling :meth:`reset_state()` will clear the queue and reset the dataflow.
""" """
def __init__(self, ds, queue=None): def __init__(self, ds, queue=None):
...@@ -180,6 +198,25 @@ class QueueInput(FeedfreeInput): ...@@ -180,6 +198,25 @@ class QueueInput(FeedfreeInput):
logger.info("Setting up the queue '{}' for CPU prefetching ...".format(self.queue.name)) logger.info("Setting up the queue '{}' for CPU prefetching ...".format(self.queue.name))
self.thread = EnqueueThread(self.queue, self.ds, self._input_placehdrs) self.thread = EnqueueThread(self.queue, self.ds, self._input_placehdrs)
self._dequeue_op = self.queue.dequeue(name='dequeue_for_reset')
def _reset_state(self):
self.thread.pause() # pause enqueue
opt = tf.RunOptions()
opt.timeout_in_ms = 2000 # 2s
sess = tf.get_default_session()
# dequeue until empty
try:
while True:
sess.run(self._dequeue_op, options=opt)
except tf.errors.DeadlineExceededError:
pass
# reset dataflow, start thread
self.thread.reset_dataflow()
self.thread.resume()
def _create_ema_callback(self): def _create_ema_callback(self):
""" """
Create a hook-only callback which maintain EMA of the queue size. Create a hook-only callback which maintain EMA of the queue size.
......
...@@ -128,6 +128,7 @@ class InputSource(object): ...@@ -128,6 +128,7 @@ class InputSource(object):
def reset_state(self): def reset_state(self):
""" """
Initialize/reinitialize this InputSource. Initialize/reinitialize this InputSource.
Must be called under a default session.
For training, it will get called by the trainer in `before_train` callbacks. For training, it will get called by the trainer in `before_train` callbacks.
For inference, the :class:`InferenceRunner` will call this method each time it is triggered. For inference, the :class:`InferenceRunner` will call this method each time it is triggered.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment