Commit c326e840 authored by Yuxin Wu

better name_scope in InputSource (#340)

parent cf2012dd
@@ -175,7 +175,7 @@ class EnqueueThread(ShareSessionThread):
         self.close_op = self.queue.close(cancel_pending_enqueues=True)
         self.size_op = self.queue.size()
         add_moving_summary(tf.cast(
-            self.size_op, tf.float32, name='input_queue_size'))
+            self.size_op, tf.float32, name='queue_size'))
 
     def run(self):
         with self.default_sess():
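The rename from 'input_queue_size' to 'queue_size' goes together with the scoping changes below: once EnqueueThread is constructed inside the 'QueueInput' name scope, the cast op already carries that scope as a prefix, so repeating 'input_queue' in the name would be redundant. A minimal sketch of the prefixing behavior (TF 1.x graph mode; the names here are illustrative):

import tensorflow as tf

with tf.name_scope('QueueInput'):
    size = tf.constant(7, name='raw_size')                        # -> QueueInput/raw_size
    summary_input = tf.cast(size, tf.float32, name='queue_size')  # -> QueueInput/queue_size

print(summary_input.op.name)  # prints: QueueInput/queue_size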
@@ -223,11 +223,13 @@ class QueueInput(FeedfreeInput):
         self._input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
         assert len(self._input_placehdrs) > 0, \
             "QueueInput has to be used with some inputs!"
-        if self.queue is None:
-            self.queue = tf.FIFOQueue(
-                50, [x.dtype for x in self._input_placehdrs],
-                name='input_queue')
-        self.thread = EnqueueThread(self.queue, self.ds, self._input_placehdrs)
+        with tf.name_scope('QueueInput') as ns:
+            self._name_scope = ns
+            if self.queue is None:
+                self.queue = tf.FIFOQueue(
+                    50, [x.dtype for x in self._input_placehdrs],
+                    name='input_queue')
+            self.thread = EnqueueThread(self.queue, self.ds, self._input_placehdrs)
 
     def _get_callbacks(self):
         from ..callbacks.concurrency import StartProcOrThread
@@ -236,7 +238,7 @@ class QueueInput(FeedfreeInput):
         return [cb]
 
     def _get_input_tensors(self):
-        with tf.device('/cpu:0'):
+        with tf.device('/cpu:0'), tf.name_scope(self._name_scope):
             ret = self.queue.dequeue(name='input_deque')
             if isinstance(ret, tf.Tensor):  # only one input
                 ret = [ret]
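The pattern in the two hunks above: open the scope once during setup, capture the string that `tf.name_scope(...) as ns` yields (it ends with '/'), and later pass that captured string back to tf.name_scope. A string with a trailing '/' re-enters the existing scope verbatim, whereas passing the plain name again would create a fresh uniquified scope such as 'QueueInput_1'. A self-contained sketch with a hypothetical class (TF 1.x):

import tensorflow as tf

class MyInput(object):
    """Hypothetical stand-in for QueueInput, showing only the scoping trick."""
    def setup(self):
        with tf.name_scope('MyInput') as ns:   # ns == 'MyInput/'
            self._name_scope = ns
            self.queue = tf.FIFOQueue(50, [tf.float32], name='input_queue')

    def get_input_tensors(self):
        # Trailing '/' means "re-enter this exact scope"; no '_1' suffix is added.
        with tf.name_scope(self._name_scope):
            return self.queue.dequeue(name='input_deque')

inp = MyInput()
inp.setup()
t = inp.get_input_tensors()
print(t.op.name)  # MyInput/input_deque (not MyInput_1/input_deque)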
@@ -287,18 +289,20 @@ class BatchQueueInput(QueueInput):
             assert p.get_shape().is_fully_defined(), shape_err
             shapes.append(p.get_shape())
 
-        if self.queue is None:
-            self.queue = tf.FIFOQueue(
-                3000, [x.dtype for x in self.input_placehdrs],
-                shapes=shapes,
-                name='input_queue')
-        for shp in self.queue.shapes:
-            assert shp.is_fully_defined(), shape_err
-
-        self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch)
+        with tf.name_scope('BatchQueueInput') as ns:
+            self._name_scope = ns
+            if self.queue is None:
+                self.queue = tf.FIFOQueue(
+                    3000, [x.dtype for x in self.input_placehdrs],
+                    shapes=shapes,
+                    name='input_queue')
+            for shp in self.queue.shapes:
+                assert shp.is_fully_defined(), shape_err
+
+            self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch)
 
     def _get_input_tensors(self):
-        with tf.device('/cpu:0'):
+        with tf.device('/cpu:0'), tf.name_scope(self._name_scope):
             ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
             if isinstance(ret, tf.Tensor):  # only one input
                 ret = [ret]
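BatchQueueInput enqueues single datapoints and forms batches with dequeue_many, which is why every component shape must be fully defined (the shape_err assertions above). A small sketch of that requirement, with made-up shapes:

import tensorflow as tf

# Each queue element is one datapoint; shapes must be fully defined
# for dequeue_many to be able to stack them into a batch.
q = tf.FIFOQueue(3000, [tf.float32], shapes=[[28, 28]], name='input_queue')
enqueue_op = q.enqueue(tf.zeros([28, 28]))
batch = q.dequeue_many(32, name='input_deque')

print(batch.get_shape())  # (32, 28, 28): batch dimension prepended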
@@ -446,19 +450,21 @@ class StagingInputWrapper(FeedfreeInput):
 
     def _setup_staging_areas(self):
         logger.info("Setting up StagingArea for GPU prefetching ...")
-        for idx, device in enumerate(self._devices):
-            with tf.device(device):
-                inputs = self._input.get_input_tensors()
-                dtypes = [x.dtype for x in inputs]
-                stage = StagingArea(dtypes, shapes=None)
-                self._stage_ops.append(stage.put(inputs))
-                self._areas.append(stage)
-                outputs = stage.get()
-                if isinstance(outputs, tf.Tensor):  # when size=1, TF doesn't return a list
-                    outputs = [outputs]
-                for vin, vout in zip(inputs, outputs):
-                    vout.set_shape(vin.get_shape())
-                self._unstage_ops.append(outputs)
+        with tf.name_scope('StagingInputWrapper') as ns:
+            self._name_scope = ns
+            for idx, device in enumerate(self._devices):
+                with tf.device(device):
+                    inputs = self._input.get_input_tensors()
+                    dtypes = [x.dtype for x in inputs]
+                    stage = StagingArea(dtypes, shapes=None)
+                    self._stage_ops.append(stage.put(inputs))
+                    self._areas.append(stage)
+                    outputs = stage.get()
+                    if isinstance(outputs, tf.Tensor):  # when size=1, TF doesn't return a list
+                        outputs = [outputs]
+                    for vin, vout in zip(inputs, outputs):
+                        vout.set_shape(vin.get_shape())
+                    self._unstage_ops.append(outputs)
 
     def _size(self):
         return self._input.size()
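For context on the block above: a StagingArea holds tensors on the device so the put for step N+1 can overlap with the compute of step N. Because the areas are built with shapes=None, get() returns tensors with unknown static shape, hence the set_shape copy from the matching input. A rough single-GPU sketch (the StagingArea import path is one TF 1.x internal location; treat it as an assumption):

import tensorflow as tf
from tensorflow.python.ops.data_flow_ops import StagingArea  # TF 1.x internal location

with tf.device('/gpu:0'):
    x = tf.zeros([128, 224, 224, 3])            # pretend input tensor
    stage = StagingArea([x.dtype], shapes=None)
    stage_op = stage.put([x])                   # run this one step ahead of get()
    outputs = stage.get()
    if isinstance(outputs, tf.Tensor):          # a single component comes back unwrapped
        outputs = [outputs]
    # shapes=None loses static shape info, so copy it back from the input:
    outputs[0].set_shape(x.get_shape())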
@@ -469,8 +475,10 @@ class StagingInputWrapper(FeedfreeInput):
         return ret
 
     def _get_stage_op(self):
-        return tf.group(*self._stage_ops)
+        with tf.name_scope(self._name_scope):
+            return tf.group(*self._stage_ops)
 
     def _get_unstage_op(self):
-        all_outputs = list(chain.from_iterable(self._unstage_ops))
-        return tf.group(*all_outputs)
+        with tf.name_scope(self._name_scope):
+            all_outputs = list(chain.from_iterable(self._unstage_ops))
+            return tf.group(*all_outputs)
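These last two methods bundle the per-device stage/unstage ops into single ops with tf.group; creating them inside the captured scope keeps the grouped ops named under 'StagingInputWrapper/' as well. A minimal illustration (op names are made up):

import tensorflow as tf

stage_gpu0 = tf.no_op(name='put0')
stage_gpu1 = tf.no_op(name='put1')

with tf.name_scope('StagingInputWrapper/'):  # trailing '/': re-enter, don't uniquify
    stage_all = tf.group(stage_gpu0, stage_gpu1)  # one op that runs all staging puts

print(stage_all.name)  # e.g. StagingInputWrapper/group_deps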