Commit 064ea7c7 authored by Yuxin Wu

cleaner scope logic in TowerContext

parent 844d8e69
@@ -8,7 +8,11 @@ from abc import abstractmethod, ABCMeta
import six
from ..utils import get_rng
__all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow']
__all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow', 'DataFlowTerminated']
class DataFlowTerminated(BaseException):
pass
@six.add_metaclass(ABCMeta)
......
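DataFlowTerminated subclasses BaseException rather than Exception, so a generic `except Exception:` handler will not swallow it; callers that want to treat dataflow teardown as a normal stop have to catch it explicitly. A minimal consumer sketch, assuming the class is re-exported by `tensorpack.dataflow` as the updated `__all__` suggests, and that `get_data()` is the DataFlow generator interface of this era:

```python
from tensorpack.dataflow import DataFlowTerminated

def drain(df):
    """Yield datapoints until the dataflow is exhausted or torn down."""
    try:
        for dp in df.get_data():
            yield dp
    except DataFlowTerminated:
        # The backing resources (e.g. a ZMQ context) are gone; stop cleanly
        # instead of surfacing a low-level error.
        return
```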
@@ -6,11 +6,12 @@ from __future__ import print_function
import multiprocessing as mp
import itertools
from six.moves import range, zip, queue
import errno
import uuid
import os
import zmq
from .base import ProxyDataFlow
from .base import ProxyDataFlow, DataFlowTerminated
from .common import RepeatedData
from ..utils.concurrency import (ensure_proc_terminate,
mask_sigint, start_proc_mask_signal,
@@ -154,8 +155,14 @@ class PrefetchDataZMQ(ProxyDataFlow):
dp = loads(self.socket.recv(copy=False).bytes)
yield dp
except zmq.ContextTerminated:
logger.info("ContextTerminated in Master Prefetch Process")
return
logger.info("[Prefetch Master] Context terminated.")
raise DataFlowTerminated()
except zmq.ZMQError as e:
if e.errno == errno.ENOTSOCK: # socket closed
logger.info("[Prefetch Master] Socket closed.")
raise DataFlowTerminated()
else:
raise
except:
raise
......
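Stripped of the class context, the error-translation pattern this hunk introduces looks roughly like the sketch below; the socket and loader arguments are placeholders, and only the two ZMQ exception types and `errno.ENOTSOCK` come from the diff:

```python
import errno

import zmq

from tensorpack.dataflow import DataFlowTerminated  # assumed re-export, as above

def recv_datapoints(socket, load_fn):
    """Receive datapoints until ZMQ teardown, then raise DataFlowTerminated."""
    try:
        while True:
            yield load_fn(socket.recv(copy=False).bytes)
    except zmq.ContextTerminated:
        # The ZMQ context was destroyed underneath the receiver.
        raise DataFlowTerminated()
    except zmq.ZMQError as e:
        if e.errno == errno.ENOTSOCK:  # the socket itself was closed
            raise DataFlowTerminated()
        raise  # any other ZMQ error is a real problem and should propagate
```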
@@ -13,7 +13,7 @@ from itertools import chain
from six.moves import range, zip
from .input_source_base import InputSource
from ..dataflow import DataFlow, RepeatedData
from ..dataflow import DataFlow, RepeatedData, DataFlowTerminated
from ..tfutils.summary import add_moving_summary
from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context
@@ -186,7 +186,7 @@ class EnqueueThread(ShareSessionThread):
feed = dict(zip(self.placehdrs, dp))
# print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self.op.run(feed_dict=feed)
except (tf.errors.CancelledError, tf.errors.OutOfRangeError):
except (tf.errors.CancelledError, tf.errors.OutOfRangeError, DataFlowTerminated):
pass
except Exception:
logger.exception("Exception in EnqueueThread:")
......
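The extra entry in the exception tuple matters because DataFlowTerminated derives from BaseException: the `except Exception:` branch below it would never catch it. A self-contained illustration of that behavior, with made-up names:

```python
class Terminated(BaseException):  # mirrors DataFlowTerminated's base class
    pass

def produce():
    raise Terminated()

try:
    produce()
except Exception:
    print("handled as a generic error")      # not reached: Terminated is not an Exception
except Terminated:
    print("handled by the explicit clause")  # why the tuple gains DataFlowTerminated
```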
@@ -19,15 +19,15 @@ class TowerContext(object):
Args:
tower_name (str): The name scope of the tower.
is_training (bool): if None, automatically determine from tower_name.
index (int): index of this tower.
index (int): index of this tower, only used in training.
vs_name (str): Open a variable scope with this name, if given.
"""
self._name = tower_name
self._is_training = bool(is_training)
if not self._is_training:
# TODO ugly
assert index == 0 and vs_name == '', "vs_name and index are meaningless in prediction!"
assert index == 0 and vs_name == '', \
"vs_name and index are only used in prediction!"
self._index = int(index)
self._vs_name = str(vs_name)
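A tiny sketch of the constructor contract documented above; it assumes the constructor only validates and stores its arguments, so no graph is needed:

```python
from tensorpack.tfutils.tower import TowerContext

# A prediction tower must keep the defaults for index and vs_name.
TowerContext('predict_tower', is_training=False)

try:
    TowerContext('predict_tower', is_training=False, index=1)
except AssertionError as e:
    print(e)  # "vs_name and index are only used in training!"
```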
@@ -85,29 +85,32 @@ class TowerContext(object):
"Nesting TowerContext!"
_CurrentTowerContext = self
self._ctxs = []
curr_vs = tf.get_variable_scope()
assert curr_vs.name == '', "Nesting TowerContext with an existing variable scope!"
# assert empty name scope as well (>1.2.1?)
if len(self._name):
if not self.is_training:
# if not training, should handle reuse outside
# but still good to clear name_scope first
self._ctxs.append(tf.name_scope(None))
self._ctxs.append(tf.name_scope(self._name))
else:
if self.has_own_variables:
if len(self.vs_name):
self._ctxs.append(tf.variable_scope(self.vs_name))
else:
if self.is_training:
self._ctxs.append(tf.name_scope(self._name))
else:
reuse = self._index > 0
if reuse is True:
# clear old name_scope (due to the existing variable_scope)
# and re-enter the current variable_scope
self._ctxs.append(tf.name_scope(None))
if reuse:
self._ctxs.append(tf.variable_scope(
tf.get_variable_scope(), reuse=True))
else:
# if not training, should handle reuse outside
# but still good to clear name_scope first
self._ctxs.append(tf.name_scope(None))
self._ctxs.append(tf.name_scope(self._name))
for c in self._ctxs:
c.__enter__()
# currently only check for predictor towers
if not self.is_training and get_tf_version_number() >= 1.2:
if get_tf_version_number() >= 1.2:
ns = tf.get_default_graph().get_name_scope()
assert ns == self._name, \
"Name conflict: name_scope inside tower '{}' becomes '{}'!".format(self._name, ns) \
......
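A hedged usage sketch of the resulting scope behavior under TF 1.x variable scoping: the main training tower (index 0) owns its variables, and a later tower re-enters the current variable scope with reuse=True per the `index > 0` branch above. Tower names here are arbitrary:

```python
import tensorflow as tf
from tensorpack.tfutils.tower import TowerContext

with TowerContext('tower0', is_training=True, index=0):
    # main training tower: only a name scope is opened, variables are owned here
    w0 = tf.get_variable('w', shape=[3])

with TowerContext('tower1', is_training=True, index=1):
    # index > 0: the current variable scope is re-entered with reuse=True,
    # so this resolves to the variable created by the first tower
    w1 = tf.get_variable('w', shape=[3])

assert w0 is w1
```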