Commit 39afd64d authored by Yuxin Wu

Reset dataflow before before_train, to avoid forking session (#494)

parent c0a81d51
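
The dataflow's reset_state() may fork worker processes (e.g. for prefetching). Previously the reset happened inside EnqueueThread.run(), i.e. after the TF session had already been created, so forked workers inherited a live session. This commit moves the reset into a callback whose setup_graph hook runs during graph setup, before the session exists and before before_train. A minimal sketch of the pattern, using the CallbackFactory and RepeatedData APIs that appear in the diff below:

    from tensorpack.callbacks.base import CallbackFactory
    from tensorpack.dataflow import RepeatedData

    def _get_reset_callback(df):
        # setup_graph fires while the graph is being built, i.e. before the
        # session exists, so df.reset_state() can fork workers safely.
        return CallbackFactory(setup_graph=lambda _: df.reset_state())

    # An InputSource now wraps its dataflow once and registers the callback:
    #     inf_ds = RepeatedData(ds, -1)
    #     callbacks = [start_thread_cb, _get_reset_callback(inf_ds)]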
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 # File: box_ops.py
+import os
 import tensorflow as tf
 from tensorpack.tfutils.scope_utils import under_name_scope
 from tensorpack.tfutils import get_default_sess_config
@@ -74,6 +75,7 @@ def get_iou_callable():
     """
     Get a pairwise box iou callable.
     """
+    os.environ['CUDA_VISIBLE_DEVICES'] = ''
     with tf.Graph().as_default(), tf.device('/cpu:0'):
         A = tf.placeholder(tf.float32, shape=[None, 4])
         B = tf.placeholder(tf.float32, shape=[None, 4])
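
Hiding all GPUs before building this throwaway graph keeps the helper session from initializing a CUDA context; tf.device('/cpu:0') alone only pins the ops. A hedged sketch of what the completed helper may look like: the pairwise_iou op and the make_callable return value are assumptions, since the rest of the function is not shown in this hunk.

    import os
    import tensorflow as tf
    from tensorpack.tfutils import get_default_sess_config

    def get_iou_callable():
        """Get a pairwise box iou callable (sketch; body partly assumed)."""
        os.environ['CUDA_VISIBLE_DEVICES'] = ''  # keep this helper session off the GPU
        with tf.Graph().as_default(), tf.device('/cpu:0'):
            A = tf.placeholder(tf.float32, shape=[None, 4])
            B = tf.placeholder(tf.float32, shape=[None, 4])
            iou = pairwise_iou(A, B)  # assumed: IoU op defined elsewhere in box_ops.py
            sess = tf.Session(config=get_default_sess_config())
            return sess.make_callable(iou, [A, B])

The remaining hunks below are in the InputSource module (FeedInput, EnqueueThread, QueueInput, BatchQueueInput).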
@@ -21,7 +21,7 @@ from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
 from ..utils.concurrency import ShareSessionThread
 from ..utils.develop import log_deprecated
-from ..callbacks.base import Callback
+from ..callbacks.base import Callback, CallbackFactory
 from ..callbacks.graph import RunOp
 __all__ = ['PlaceholderInput', 'FeedInput',
@@ -33,6 +33,10 @@ __all__ = ['PlaceholderInput', 'FeedInput',
            'StagingInput']
+def _get_reset_callback(df):
+    return CallbackFactory(setup_graph=lambda _: df.reset_state())
 class PlaceholderInput(InputSource):
     """
     Just produce placeholders as input tensors.
@@ -99,7 +103,7 @@ class FeedInput(InputSource):
         self._cb._reset()
     def _get_callbacks(self):
-        return [self._cb]
+        return [self._cb, _get_reset_callback(self._iter_ds)]
 class FeedfreeInput(InputSource):
@@ -116,10 +120,8 @@ class EnqueueThread(ShareSessionThread):
         super(EnqueueThread, self).__init__()
         self.name = 'EnqueueThread ' + queue.name
         self.daemon = True
-        self.dataflow = RepeatedData(ds, -1)
+        self.dataflow = ds
         self.queue = queue
         self.placehdrs = placehdrs
         self.op = self.queue.enqueue(self.placehdrs)
@@ -130,7 +132,7 @@ class EnqueueThread(ShareSessionThread):
     def run(self):
         with self.default_sess():
             try:
-                self.reset_dataflow()
+                self._itr = self.dataflow.get_data()
                 while True:
                     # pausable loop
                     self._lock.acquire()
@@ -182,6 +184,7 @@ class QueueInput(FeedfreeInput):
         assert isinstance(ds, DataFlow), ds
         self.queue = queue
         self.ds = ds
+        self._inf_ds = RepeatedData(ds, -1)
     def _size(self):
         return self.ds.size()
@@ -196,7 +199,7 @@ class QueueInput(FeedfreeInput):
                 50, [x.dtype for x in self._input_placehdrs],
                 name='input_queue')
         logger.info("Setting up the queue '{}' for CPU prefetching ...".format(self.queue.name))
-        self.thread = EnqueueThread(self.queue, self.ds, self._input_placehdrs)
+        self.thread = EnqueueThread(self.queue, self._inf_ds, self._input_placehdrs)
         self._dequeue_op = self.queue.dequeue(name='dequeue_for_reset')
@@ -236,7 +239,7 @@ class QueueInput(FeedfreeInput):
         from ..callbacks.concurrency import StartProcOrThread
         cb = StartProcOrThread(self.thread)
         cb.chief_only = False
-        return [cb, self._create_ema_callback()]
+        return [cb, self._create_ema_callback(), _get_reset_callback(self._inf_ds)]
     def _get_input_tensors(self):
         with tf.device('/cpu:0'), self.cached_name_scope():
@@ -299,7 +302,7 @@ class BatchQueueInput(QueueInput):
         for shp in self.queue.shapes:
             assert shp.is_fully_defined(), shape_err
-        self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch)
+        self.thread = EnqueueThread(self.queue, self._inf_ds, placehdrs_nobatch)
     def _get_input_tensors(self):
         with tf.device('/cpu:0'), self.cached_name_scope():
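
Net effect in the input sources: the RepeatedData(ds, -1) wrapping moves out of EnqueueThread into QueueInput (self._inf_ds), and the reset is driven by the new callback at graph-setup time rather than inside the enqueue thread. A rough illustration of why the ordering matters; PrefetchDataZMQ stands in for any dataflow whose reset_state() forks worker processes, and my_dataflow is a placeholder (neither comes from this commit):

    from tensorpack import QueueInput
    from tensorpack.dataflow import PrefetchDataZMQ

    df = PrefetchDataZMQ(my_dataflow, nr_proc=4)  # my_dataflow: any DataFlow (placeholder)
    inp = QueueInput(df)
    # Old behaviour: EnqueueThread.run() reset the dataflow after the session was
    # created, so the four ZMQ workers were forked from a process holding a live
    # TF session. New behaviour: _get_reset_callback(...) resets it in setup_graph,
    # before the session (and before before_train), so the fork is clean.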