Commit 064ea7c7 authored by Yuxin Wu's avatar Yuxin Wu

cleaner scope logic in TowerContext

parent 844d8e69
...@@ -8,7 +8,11 @@ from abc import abstractmethod, ABCMeta ...@@ -8,7 +8,11 @@ from abc import abstractmethod, ABCMeta
import six import six
from ..utils import get_rng from ..utils import get_rng
__all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow'] __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow', 'DataFlowTerminated']
class DataFlowTerminated(BaseException):
pass
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
......
...@@ -6,11 +6,12 @@ from __future__ import print_function ...@@ -6,11 +6,12 @@ from __future__ import print_function
import multiprocessing as mp import multiprocessing as mp
import itertools import itertools
from six.moves import range, zip, queue from six.moves import range, zip, queue
import errno
import uuid import uuid
import os import os
import zmq import zmq
from .base import ProxyDataFlow from .base import ProxyDataFlow, DataFlowTerminated
from .common import RepeatedData from .common import RepeatedData
from ..utils.concurrency import (ensure_proc_terminate, from ..utils.concurrency import (ensure_proc_terminate,
mask_sigint, start_proc_mask_signal, mask_sigint, start_proc_mask_signal,
...@@ -154,8 +155,14 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -154,8 +155,14 @@ class PrefetchDataZMQ(ProxyDataFlow):
dp = loads(self.socket.recv(copy=False).bytes) dp = loads(self.socket.recv(copy=False).bytes)
yield dp yield dp
except zmq.ContextTerminated: except zmq.ContextTerminated:
logger.info("ContextTerminated in Master Prefetch Process") logger.info("[Prefetch Master] Context terminated.")
return raise DataFlowTerminated()
except zmq.ZMQError as e:
if e.errno == errno.ENOTSOCK: # socket closed
logger.info("[Prefetch Master] Socket closed.")
raise DataFlowTerminated()
else:
raise
except: except:
raise raise
......
...@@ -13,7 +13,7 @@ from itertools import chain ...@@ -13,7 +13,7 @@ from itertools import chain
from six.moves import range, zip from six.moves import range, zip
from .input_source_base import InputSource from .input_source_base import InputSource
from ..dataflow import DataFlow, RepeatedData from ..dataflow import DataFlow, RepeatedData, DataFlowTerminated
from ..tfutils.summary import add_moving_summary from ..tfutils.summary import add_moving_summary
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
...@@ -186,7 +186,7 @@ class EnqueueThread(ShareSessionThread): ...@@ -186,7 +186,7 @@ class EnqueueThread(ShareSessionThread):
feed = dict(zip(self.placehdrs, dp)) feed = dict(zip(self.placehdrs, dp))
# print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1] # print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self.op.run(feed_dict=feed) self.op.run(feed_dict=feed)
except (tf.errors.CancelledError, tf.errors.OutOfRangeError): except (tf.errors.CancelledError, tf.errors.OutOfRangeError, DataFlowTerminated):
pass pass
except Exception: except Exception:
logger.exception("Exception in EnqueueThread:") logger.exception("Exception in EnqueueThread:")
......
...@@ -19,15 +19,15 @@ class TowerContext(object): ...@@ -19,15 +19,15 @@ class TowerContext(object):
Args: Args:
tower_name (str): The name scope of the tower. tower_name (str): The name scope of the tower.
is_training (bool): if None, automatically determine from tower_name. is_training (bool): if None, automatically determine from tower_name.
index (int): index of this tower. index (int): index of this tower, only used in training.
vs_name (str): Open a variable scope with this name, if given. vs_name (str): Open a variable scope with this name, if given.
""" """
self._name = tower_name self._name = tower_name
self._is_training = bool(is_training) self._is_training = bool(is_training)
if not self._is_training: if not self._is_training:
# TODO ugly assert index == 0 and vs_name == '', \
assert index == 0 and vs_name == '', "vs_name and index are meaningless in prediction!" "vs_name and index are only used in prediction!"
self._index = int(index) self._index = int(index)
self._vs_name = str(vs_name) self._vs_name = str(vs_name)
...@@ -85,29 +85,32 @@ class TowerContext(object): ...@@ -85,29 +85,32 @@ class TowerContext(object):
"Nesting TowerContext!" "Nesting TowerContext!"
_CurrentTowerContext = self _CurrentTowerContext = self
self._ctxs = [] self._ctxs = []
curr_vs = tf.get_variable_scope()
assert curr_vs.name == '', "Nesting TowerContext with an existing variable scope!"
# assert empty name scope as well (>1.2.1?)
if len(self._name): if len(self._name):
if not self.is_training:
# if not training, should handle reuse outside
# but still good to clear name_scope first
self._ctxs.append(tf.name_scope(None))
self._ctxs.append(tf.name_scope(self._name))
else:
if self.has_own_variables: if self.has_own_variables:
if len(self.vs_name): if len(self.vs_name):
self._ctxs.append(tf.variable_scope(self.vs_name)) self._ctxs.append(tf.variable_scope(self.vs_name))
else: else:
if self.is_training: self._ctxs.append(tf.name_scope(self._name))
else:
reuse = self._index > 0 reuse = self._index > 0
if reuse is True: if reuse:
# clear old name_scope (due to the existing variable_scope)
# and re-enter the current variable_scope
self._ctxs.append(tf.name_scope(None))
self._ctxs.append(tf.variable_scope( self._ctxs.append(tf.variable_scope(
tf.get_variable_scope(), reuse=True)) tf.get_variable_scope(), reuse=True))
else:
# if not training, should handle reuse outside
# but still good to clear name_scope first
self._ctxs.append(tf.name_scope(None))
self._ctxs.append(tf.name_scope(self._name)) self._ctxs.append(tf.name_scope(self._name))
for c in self._ctxs: for c in self._ctxs:
c.__enter__() c.__enter__()
# currently only check for predictor towers # currently only check for predictor towers
if not self.is_training and get_tf_version_number() >= 1.2: if get_tf_version_number() >= 1.2:
ns = tf.get_default_graph().get_name_scope() ns = tf.get_default_graph().get_name_scope()
assert ns == self._name, \ assert ns == self._name, \
"Name conflict: name_scope inside tower '{}' becomes '{}'!".format(self._name, ns) \ "Name conflict: name_scope inside tower '{}' becomes '{}'!".format(self._name, ns) \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment