Commit be4759be authored by Yuxin Wu's avatar Yuxin Wu

misc update

parent 58c2779f
......@@ -20,8 +20,7 @@ from .parallel import (
__all__ = ['ThreadedMapData', 'MultiThreadMapData',
'MultiProcessMapData', 'MultiProcessMapDataZMQ',
'MultiProcessMapDataComponentSharedArray']
'MultiProcessMapData', 'MultiProcessMapDataZMQ']
class _ParallelMapData(ProxyDataFlow):
......@@ -302,6 +301,7 @@ def _pool_map(data):
return WORKER_ID
# TODO shutdown pool, improve speed.
class MultiProcessMapDataComponentSharedArray(DataFlow):
"""
Similar to :class:`MapDataComponent`, but perform IPC by shared memory,
......
......@@ -7,6 +7,8 @@ from contextlib import contextmanager
import operator
import tensorflow as tf
from ..tfutils.common import get_tf_version_number
__all__ = ['LeastLoadedDeviceSetter',
'OverrideCachingDevice',
......@@ -41,12 +43,22 @@ def override_to_local_variable(enable=True):
return getter(name, *args, **kwargs)
orig_vs = tf.get_variable_scope()
# TODO TF1.5 has https://github.com/tensorflow/tensorflow/pull/14390
with tf.variable_scope(
if get_tf_version_number() >= 1.5:
with tf.variable_scope(
tf.get_variable_scope(),
custom_getter=custom_getter):
with tf.name_scope(orig_vs.original_name_scope):
custom_getter=custom_getter,
auxiliary_name_scope=False):
yield
else:
if get_tf_version_number() >= 1.2:
ns = tf.get_default_graph().get_name_scope()
else:
ns = tf.get_variable_scope().original_name_scope
with tf.variable_scope(
tf.get_variable_scope(),
custom_getter=custom_getter):
with tf.name_scope(ns + '/'):
yield
else:
yield
......
......@@ -118,9 +118,10 @@ class EnqueueThread(ShareSessionThread):
self.close_op = self.queue.close(cancel_pending_enqueues=True)
self._lock = threading.Lock()
# self._size = queue.size()
def run(self):
with self.default_sess():
with self.default_sess() as sess:
try:
self.reinitialize_dataflow()
while True:
......@@ -130,6 +131,7 @@ class EnqueueThread(ShareSessionThread):
dp = next(self._itr)
feed = dict(zip(self.placehdrs, dp))
# _, sz = sess.run([self.op, self._sz], feed_dict=feed)
self.op.run(feed_dict=feed)
except (tf.errors.CancelledError, tf.errors.OutOfRangeError, DataFlowTerminated):
pass
......
......@@ -124,10 +124,10 @@ class ShareSessionThread(threading.Thread):
def default_sess(self):
if self._sess:
with self._sess.as_default():
yield
yield self._sess
else:
logger.warn("ShareSessionThread {} wasn't under a default session!".format(self.name))
yield
yield None
def start(self):
import tensorflow as tf
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment