Commit c326e840 authored by Yuxin Wu

better name_scope in InputSource (#340)

parent cf2012dd
@@ -175,7 +175,7 @@ class EnqueueThread(ShareSessionThread):
         self.close_op = self.queue.close(cancel_pending_enqueues=True)
         self.size_op = self.queue.size()
         add_moving_summary(tf.cast(
-            self.size_op, tf.float32, name='input_queue_size'))
+            self.size_op, tf.float32, name='queue_size'))

     def run(self):
         with self.default_sess():
@@ -223,6 +223,8 @@ class QueueInput(FeedfreeInput):
         self._input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
         assert len(self._input_placehdrs) > 0, \
             "QueueInput has to be used with some inputs!"
+        with tf.name_scope('QueueInput') as ns:
+            self._name_scope = ns
         if self.queue is None:
             self.queue = tf.FIFOQueue(
                 50, [x.dtype for x in self._input_placehdrs],
@@ -236,7 +238,7 @@ class QueueInput(FeedfreeInput):
         return [cb]

     def _get_input_tensors(self):
-        with tf.device('/cpu:0'):
+        with tf.device('/cpu:0'), tf.name_scope(self._name_scope):
             ret = self.queue.dequeue(name='input_deque')
             if isinstance(ret, tf.Tensor):  # only one input
                 ret = [ret]
@@ -287,6 +289,8 @@ class BatchQueueInput(QueueInput):
             assert p.get_shape().is_fully_defined(), shape_err
             shapes.append(p.get_shape())

+        with tf.name_scope('BatchQueueInput') as ns:
+            self._name_scope = ns
         if self.queue is None:
             self.queue = tf.FIFOQueue(
                 3000, [x.dtype for x in self.input_placehdrs],
@@ -298,7 +302,7 @@ class BatchQueueInput(QueueInput):
         self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch)

     def _get_input_tensors(self):
-        with tf.device('/cpu:0'):
+        with tf.device('/cpu:0'), tf.name_scope(self._name_scope):
             ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
             if isinstance(ret, tf.Tensor):  # only one input
                 ret = [ret]
@@ -446,6 +450,8 @@ class StagingInputWrapper(FeedfreeInput):

     def _setup_staging_areas(self):
         logger.info("Setting up StagingArea for GPU prefetching ...")
+        with tf.name_scope('StagingInputWrapper') as ns:
+            self._name_scope = ns
         for idx, device in enumerate(self._devices):
             with tf.device(device):
                 inputs = self._input.get_input_tensors()
@@ -469,8 +475,10 @@ class StagingInputWrapper(FeedfreeInput):
         return ret

     def _get_stage_op(self):
-        return tf.group(*self._stage_ops)
+        with tf.name_scope(self._name_scope):
+            return tf.group(*self._stage_ops)

     def _get_unstage_op(self):
-        all_outputs = list(chain.from_iterable(self._unstage_ops))
-        return tf.group(*all_outputs)
+        with tf.name_scope(self._name_scope):
+            all_outputs = list(chain.from_iterable(self._unstage_ops))
+            return tf.group(*all_outputs)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment