Commit be4759be authored by Yuxin Wu's avatar Yuxin Wu

misc update

parent 58c2779f
...@@ -20,8 +20,7 @@ from .parallel import ( ...@@ -20,8 +20,7 @@ from .parallel import (
__all__ = ['ThreadedMapData', 'MultiThreadMapData', __all__ = ['ThreadedMapData', 'MultiThreadMapData',
'MultiProcessMapData', 'MultiProcessMapDataZMQ', 'MultiProcessMapData', 'MultiProcessMapDataZMQ']
'MultiProcessMapDataComponentSharedArray']
class _ParallelMapData(ProxyDataFlow): class _ParallelMapData(ProxyDataFlow):
...@@ -302,6 +301,7 @@ def _pool_map(data): ...@@ -302,6 +301,7 @@ def _pool_map(data):
return WORKER_ID return WORKER_ID
# TODO shutdown pool, improve speed.
class MultiProcessMapDataComponentSharedArray(DataFlow): class MultiProcessMapDataComponentSharedArray(DataFlow):
""" """
Similar to :class:`MapDataComponent`, but perform IPC by shared memory, Similar to :class:`MapDataComponent`, but perform IPC by shared memory,
......
...@@ -7,6 +7,8 @@ from contextlib import contextmanager ...@@ -7,6 +7,8 @@ from contextlib import contextmanager
import operator import operator
import tensorflow as tf import tensorflow as tf
from ..tfutils.common import get_tf_version_number
__all__ = ['LeastLoadedDeviceSetter', __all__ = ['LeastLoadedDeviceSetter',
'OverrideCachingDevice', 'OverrideCachingDevice',
...@@ -41,12 +43,22 @@ def override_to_local_variable(enable=True): ...@@ -41,12 +43,22 @@ def override_to_local_variable(enable=True):
return getter(name, *args, **kwargs) return getter(name, *args, **kwargs)
orig_vs = tf.get_variable_scope() orig_vs = tf.get_variable_scope()
# TODO TF1.5 has https://github.com/tensorflow/tensorflow/pull/14390 if get_tf_version_number() >= 1.5:
with tf.variable_scope( with tf.variable_scope(
tf.get_variable_scope(), tf.get_variable_scope(),
custom_getter=custom_getter): custom_getter=custom_getter,
with tf.name_scope(orig_vs.original_name_scope): auxiliary_name_scope=False):
yield yield
else:
if get_tf_version_number() >= 1.2:
ns = tf.get_default_graph().get_name_scope()
else:
ns = tf.get_variable_scope().original_name_scope
with tf.variable_scope(
tf.get_variable_scope(),
custom_getter=custom_getter):
with tf.name_scope(ns + '/'):
yield
else: else:
yield yield
......
...@@ -118,9 +118,10 @@ class EnqueueThread(ShareSessionThread): ...@@ -118,9 +118,10 @@ class EnqueueThread(ShareSessionThread):
self.close_op = self.queue.close(cancel_pending_enqueues=True) self.close_op = self.queue.close(cancel_pending_enqueues=True)
self._lock = threading.Lock() self._lock = threading.Lock()
# self._size = queue.size()
def run(self): def run(self):
with self.default_sess(): with self.default_sess() as sess:
try: try:
self.reinitialize_dataflow() self.reinitialize_dataflow()
while True: while True:
...@@ -130,6 +131,7 @@ class EnqueueThread(ShareSessionThread): ...@@ -130,6 +131,7 @@ class EnqueueThread(ShareSessionThread):
dp = next(self._itr) dp = next(self._itr)
feed = dict(zip(self.placehdrs, dp)) feed = dict(zip(self.placehdrs, dp))
# _, sz = sess.run([self.op, self._sz], feed_dict=feed)
self.op.run(feed_dict=feed) self.op.run(feed_dict=feed)
except (tf.errors.CancelledError, tf.errors.OutOfRangeError, DataFlowTerminated): except (tf.errors.CancelledError, tf.errors.OutOfRangeError, DataFlowTerminated):
pass pass
......
...@@ -124,10 +124,10 @@ class ShareSessionThread(threading.Thread): ...@@ -124,10 +124,10 @@ class ShareSessionThread(threading.Thread):
def default_sess(self): def default_sess(self):
if self._sess: if self._sess:
with self._sess.as_default(): with self._sess.as_default():
yield yield self._sess
else: else:
logger.warn("ShareSessionThread {} wasn't under a default session!".format(self.name)) logger.warn("ShareSessionThread {} wasn't under a default session!".format(self.name))
yield yield None
def start(self): def start(self):
import tensorflow as tf import tensorflow as tf
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment