Commit 39afd64d authored by Yuxin Wu

Reset dataflow before before_train, to avoid forking session (#494)

parent c0a81d51
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 # File: box_ops.py

+import os
 import tensorflow as tf
 from tensorpack.tfutils.scope_utils import under_name_scope
 from tensorpack.tfutils import get_default_sess_config
@@ -74,6 +75,7 @@ def get_iou_callable():
     """
     Get a pairwise box iou callable.
     """
+    os.environ['CUDA_VISIBLE_DEVICES'] = ''
     with tf.Graph().as_default(), tf.device('/cpu:0'):
         A = tf.placeholder(tf.float32, shape=[None, 4])
         B = tf.placeholder(tf.float32, shape=[None, 4])
......
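Note: only part of get_iou_callable() is visible in this hunk. To illustrate what the added line does, here is a minimal, self-contained sketch of a CPU-only pairwise-IoU callable; the helper name make_pairwise_iou_callable and the IoU computation below are illustrative assumptions, not the code in box_ops.py. The important detail is that CUDA_VISIBLE_DEVICES is cleared before the helper session is created, so this small utility graph never claims GPU memory.

# Rough sketch only -- make_pairwise_iou_callable is a hypothetical helper,
# not the function defined in box_ops.py.
import os
import tensorflow as tf


def make_pairwise_iou_callable():
    # Hiding all GPUs must happen before any session is created.
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        A = tf.placeholder(tf.float32, shape=[None, 4])  # boxes as [x1, y1, x2, y2]
        B = tf.placeholder(tf.float32, shape=[None, 4])

        def area(boxes):
            return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

        # Broadcast A against B to get all pairwise intersections.
        lt = tf.maximum(A[:, None, :2], B[None, :, :2])   # top-left of intersection
        rb = tf.minimum(A[:, None, 2:], B[None, :, 2:])   # bottom-right of intersection
        wh = tf.maximum(rb - lt, 0.0)
        inter = wh[:, :, 0] * wh[:, :, 1]
        union = area(A)[:, None] + area(B)[None, :] - inter
        iou = inter / tf.maximum(union, 1e-10)

        sess = tf.Session()
        return lambda boxes_a, boxes_b: sess.run(iou, feed_dict={A: boxes_a, B: boxes_b})


# Example usage:
# iou_fn = make_pairwise_iou_callable()
# iou_fn([[0., 0., 10., 10.]], [[5., 5., 15., 15.]])

Clearing the environment variable (rather than relying on tf.device('/cpu:0') alone) presumably matters because it keeps the helper session from initializing the GPUs at all.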
@@ -21,7 +21,7 @@ from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
 from ..utils.concurrency import ShareSessionThread
 from ..utils.develop import log_deprecated
-from ..callbacks.base import Callback
+from ..callbacks.base import Callback, CallbackFactory
 from ..callbacks.graph import RunOp

 __all__ = ['PlaceholderInput', 'FeedInput',
@@ -33,6 +33,10 @@ __all__ = ['PlaceholderInput', 'FeedInput',
            'StagingInput']


+def _get_reset_callback(df):
+    return CallbackFactory(setup_graph=lambda _: df.reset_state())
+
+
 class PlaceholderInput(InputSource):
     """
     Just produce placeholders as input tensors.
@@ -99,7 +103,7 @@ class FeedInput(InputSource):
         self._cb._reset()

     def _get_callbacks(self):
-        return [self._cb]
+        return [self._cb, _get_reset_callback(self._iter_ds)]


 class FeedfreeInput(InputSource):
@@ -116,10 +120,8 @@ class EnqueueThread(ShareSessionThread):
         super(EnqueueThread, self).__init__()
         self.name = 'EnqueueThread ' + queue.name
         self.daemon = True
-
-        self.dataflow = RepeatedData(ds, -1)
-
+        self.dataflow = ds
         self.queue = queue
         self.placehdrs = placehdrs

         self.op = self.queue.enqueue(self.placehdrs)
@@ -130,7 +132,7 @@ class EnqueueThread(ShareSessionThread):
     def run(self):
         with self.default_sess():
             try:
-                self.reset_dataflow()
+                self._itr = self.dataflow.get_data()
                 while True:
                     # pausable loop
                     self._lock.acquire()
@@ -182,6 +184,7 @@ class QueueInput(FeedfreeInput):
         assert isinstance(ds, DataFlow), ds
         self.queue = queue
         self.ds = ds
+        self._inf_ds = RepeatedData(ds, -1)

     def _size(self):
         return self.ds.size()
@@ -196,7 +199,7 @@ class QueueInput(FeedfreeInput):
                 50, [x.dtype for x in self._input_placehdrs],
                 name='input_queue')
         logger.info("Setting up the queue '{}' for CPU prefetching ...".format(self.queue.name))
-        self.thread = EnqueueThread(self.queue, self.ds, self._input_placehdrs)
+        self.thread = EnqueueThread(self.queue, self._inf_ds, self._input_placehdrs)

         self._dequeue_op = self.queue.dequeue(name='dequeue_for_reset')
@@ -236,7 +239,7 @@ class QueueInput(FeedfreeInput):
         from ..callbacks.concurrency import StartProcOrThread
         cb = StartProcOrThread(self.thread)
         cb.chief_only = False
-        return [cb, self._create_ema_callback()]
+        return [cb, self._create_ema_callback(), _get_reset_callback(self._inf_ds)]

     def _get_input_tensors(self):
         with tf.device('/cpu:0'), self.cached_name_scope():
@@ -299,7 +302,7 @@ class BatchQueueInput(QueueInput):
             for shp in self.queue.shapes:
                 assert shp.is_fully_defined(), shape_err

-        self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch)
+        self.thread = EnqueueThread(self.queue, self._inf_ds, placehdrs_nobatch)

     def _get_input_tensors(self):
         with tf.device('/cpu:0'), self.cached_name_scope():
......
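Note on the input_source.py changes: reset_state() on a prefetching dataflow may fork worker processes. Previously the reset ran inside EnqueueThread.run(), and that thread is only started around before_train, i.e. after the TF session exists, so the forked children inherited the session (#494). The new _get_reset_callback wraps the reset in CallbackFactory(setup_graph=...), which tensorpack runs while the graph is being set up, before any session or thread is created. A minimal sketch of the pattern follows; ForkingDataFlow is a hypothetical stand-in, not code from this commit.

# Minimal sketch of the reset-callback pattern introduced above.
from tensorpack.callbacks.base import CallbackFactory
from tensorpack.dataflow import DataFlow


class ForkingDataFlow(DataFlow):
    """Hypothetical stand-in for a dataflow whose reset_state() forks workers."""

    def reset_state(self):
        # In a real prefetching dataflow this is where worker processes would be
        # forked -- safe here because no TF session exists yet at setup_graph time.
        print("reset_state() called during setup_graph")

    def get_data(self):
        while True:
            yield [0]


def get_reset_callback(df):
    # Same shape as the _get_reset_callback helper in the diff: run
    # df.reset_state() in setup_graph, before the session is created
    # and before before_train runs.
    return CallbackFactory(setup_graph=lambda _: df.reset_state())


df = ForkingDataFlow()
reset_cb = get_reset_callback(df)   # returned from an InputSource's callback list

In the diff above, QueueInput and FeedInput now return exactly such a callback from _get_callbacks(), so the dataflow is reset during graph setup instead of inside the enqueue thread.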