Commit 3e97f126 authored by Yuxin Wu

use tf.train.MonitoredSession. enforce a boundary between graph finalize & session create

parent 224025e3
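
In short: the trainer no longer owns a raw `tf.Session` plus a `tf.train.Coordinator`. The whole graph is built first; session creation is then delegated to `tf.train.MonitoredSession`, which finalizes the graph, and threads that need the session pick up the *default* session through the new `ShareSessionThread` helper instead of holding a `trainer.sess` reference. A minimal sketch of the new flow, condensed from the trainer changes below (the toy graph is a placeholder, not code from this commit):

```
import tensorflow as tf

# 1. Build the complete graph first: model, callbacks, summaries.
x = tf.placeholder(tf.float32, [None, 4])
w = tf.Variable(tf.zeros([4, 1]))
loss = tf.reduce_sum(tf.matmul(x, w))

# 2. The boundary: creating the MonitoredSession finalizes the graph,
#    so nothing below this point may add ops.
scaffold = tf.train.Scaffold(init_op=tf.global_variables_initializer())
mon_sess = tf.train.MonitoredSession(
    session_creator=tf.train.ChiefSessionCreator(scaffold=scaffold))

print(mon_sess.run(loss, feed_dict={x: [[1., 2., 3., 4.]]}))  # 0.0

# 3. A raw tf.Session is still reachable where an API requires one;
#    the trainer below uses monitored_sess._tf_sess() for this.
```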
...
@@ -62,9 +62,6 @@ Dependencies:
 + Python 2 or 3
 + TensorFlow >= 1.0.0rc0
 + Python bindings for OpenCV
-+ (optional) use tcmalloc if running with large data
 ```
 pip install --user -U git+https://github.com/ppwwyyxx/tensorpack.git
-pip install --user -r opt-requirements.txt # (some optional dependencies required by certain submodules, you can install later if prompted)
 ```
...
@@ -23,8 +23,8 @@ from tensorpack.utils.serialize import loads, dumps
 from tensorpack.utils.concurrency import LoopThread, ensure_proc_terminate

 __all__ = ['SimulatorProcess', 'SimulatorMaster',
-           'SimulatorProcessStateExchange', 'SimulatorProcessSharedWeight',
-           'TransitionExperience', 'WeightSync']
+           'SimulatorProcessStateExchange',
+           'TransitionExperience']

 class TransitionExperience(object):
...
@@ -38,7 +38,7 @@ CHANNEL = FRAME_HISTORY * 3
 IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
 LOCAL_TIME_MAX = 5
-STEP_PER_EPOCH = 6000
+STEPS_PER_EPOCH = 6000
 EVAL_EPISODE = 50
 BATCH_SIZE = 128
 SIMULATOR_PROC = 50
...
@@ -150,11 +150,12 @@ class MySimulatorMaster(SimulatorMaster, Callback):
         self.queue = queue.Queue(maxsize=BATCH_SIZE * 8 * 2)

     def _setup_graph(self):
-        self.sess = self.trainer.sess
         self.async_predictor = MultiThreadAsyncPredictor(
             self.trainer.get_predict_funcs(['state'], ['logitsT', 'pred_value'],
                                            PREDICTOR_THREAD), batch_size=15)
-        self.async_predictor.run()
+
+    def _before_train(self):
+        self.async_predictor.start()

     def _on_state(self, state, ident):
         def cb(outputs):
...
@@ -222,7 +223,7 @@ def get_config():
         ],
         session_config=get_default_sess_config(0.5),
         model=M,
-        steps_per_epoch=STEP_PER_EPOCH,
+        steps_per_epoch=STEPS_PER_EPOCH,
         max_epoch=1000,
     )
...
@@ -40,7 +40,7 @@ def play_model(cfg):
 def eval_with_funcs(predict_funcs, nr_eval):
-    class Worker(StoppableThread):
+    class Worker(StoppableThread, ShareSessionThread):
         def __init__(self, func, queue):
             super(Worker, self).__init__()
             self._func = func
...
@@ -52,14 +52,15 @@ def eval_with_funcs(predict_funcs, nr_eval):
             return self._func(*args, **kwargs)

         def run(self):
-            player = get_player(train=False)
-            while not self.stopped():
-                try:
-                    score = play_one_episode(player, self.func)
-                    # print "Score, ", score
-                except RuntimeError:
-                    return
-                self.queue_put_stoppable(self.q, score)
+            with self.default_sess():
+                player = get_player(train=False)
+                while not self.stopped():
+                    try:
+                        score = play_one_episode(player, self.func)
+                        # print "Score, ", score
+                    except RuntimeError:
+                        return
+                    self.queue_put_stoppable(self.q, score)

     q = queue.Queue()
     threads = [Worker(f, q) for f in predict_funcs]
...
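The `Worker` change above is the mixin form of the new `ShareSessionThread`: the thread body runs inside `self.default_sess()`, which re-installs whatever default session was active when `start()` was called. That makes the start-up contract explicit; a sketch of the calling side (`sess` here is a stand-in for the session created elsewhere, not a name from this hunk):

```
threads = [Worker(f, q) for f in predict_funcs]
with sess.as_default():      # sess: an existing tf.Session (assumed)
    for t in threads:
        t.start()            # ShareSessionThread.start() captures sess
```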
@@ -11,7 +11,7 @@ from six.moves import queue
 from tensorpack.dataflow import DataFlow
 from tensorpack.utils import logger, get_tqdm, get_rng
-from tensorpack.utils.concurrency import LoopThread
+from tensorpack.utils.concurrency import LoopThread, ShareSessionThread
 from tensorpack.callbacks.base import Callback

 __all__ = ['ExpReplay']
...
@@ -75,10 +75,9 @@ class ExpReplay(DataFlow, Callback):
         # spawn a separate thread to run policy, can speed up 1.3x
         def populate_job_func():
             self._populate_job_queue.get()
-            with self.trainer.sess.as_default():
-                for _ in range(self.update_frequency):
-                    self._populate_exp()
-        th = LoopThread(populate_job_func, pausable=False)
+            for _ in range(self.update_frequency):
+                self._populate_exp()
+        th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
         th.name = "SimulatorThread"
         return th
...
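This hunk uses the other supported form: wrapping an already-constructed thread. `ShareSessionThread(LoopThread(...))` copies the inner thread's name and daemon flag, and its `run()` executes the inner thread's `run()` under the captured session, so `populate_job_func` no longer needs a `trainer.sess` reference. A sketch of the lifecycle, assuming the thread is started while a default session is active (as the `StartProcOrThread` callback arranges elsewhere in this codebase):

```
def populate_job_func():
    pass   # placeholder: runs ops, hence needs a session

th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
th.name = "SimulatorThread"
# later, under a default session:
th.start()   # records tf.get_default_session(); run() wraps LoopThread.run()
```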
@@ -21,9 +21,9 @@ This implementation uses the variants proposed in:
 Identity Mappings in Deep Residual Networks, arxiv:1603.05027

 I can reproduce the results on 2 TitanX for
-n=5, about 7.1% val error after 67k steps (8.6 step/s)
-n=18, about 5.95% val error after 80k steps (2.6 step/s)
-n=30: a 182-layer network, about 5.6% val error after 51k steps (1.55 step/s)
+n=5, about 7.1% val error after 67k steps (15 step/s)
+n=18, about 5.95% val error after 80k steps (4.2 step/s)
+n=30: a 182-layer network, about 5.6% val error after 51k steps (2.5 step/s)

 This model uses the whole training set instead of a train-val split.
 To train:
...
@@ -131,7 +131,7 @@ def get_data(train_or_test):
         imgaug.MapImage(lambda x: x - pp_mean)
     ]
     ds = AugmentImageComponent(ds, augmentors)
-    ds = BatchData(ds, 128, remainder=not isTrain)
+    ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
     if isTrain:
         ds = PrefetchData(ds, 3, 2)
     return ds
...
@@ -226,18 +226,19 @@ class FeedfreeInferenceRunner(Triggerable):
         G = tf.get_default_graph()
         self._output_tensors = [G.get_tensor_by_name(
             self._tower_prefix + '/' + n) for n in all_names]
-        self._sess = self.trainer.sess

         # list of list of id
         self.inf_to_idxs = dispatcer.get_idx_for_each_entry()

     def _trigger(self):
+        sess = tf.get_default_session()
         for inf in self.infs:
             inf.before_inference()

         with get_tqdm(total=self._size) as pbar:
             for _ in range(self._size):
-                outputs = self._sess.run(fetches=self._output_tensors)
+                outputs = sess.run(fetches=self._output_tensors)
                 for inf, idlist in zip(self.infs, self.inf_to_idxs):
                     inf_output = [outputs[k] for k in idlist]
                     inf.datapoint(inf_output)
...
@@ -6,4 +6,4 @@ import cv2 # noqa
 import os
 os.environ['OPENCV_OPENCL_RUNTIME'] = ''
-__version__ = '0.1.5'
+__version__ = '0.1.6'
...
@@ -23,7 +23,6 @@ class PredictorBase(object):
     Base class for all predictors.

     Attributes:
-        session (tf.Session):
         return_input (bool): whether the call will also return (inputs, outputs)
             or just outputs
     """
...
@@ -91,25 +90,30 @@ class AsyncPredictorBase(PredictorBase):

 class OnlinePredictor(PredictorBase):
     """ A predictor which directly uses an existing session. """

-    def __init__(self, sess, input_tensors, output_tensors, return_input=False):
+    def __init__(self, input_tensors, output_tensors,
+                 return_input=False, sess=None):
         """
         Args:
-            sess (tf.Session): an existing session.
             input_tensors (list): list of names.
             output_tensors (list): list of names.
             return_input (bool): same as :attr:`PredictorBase.return_input`.
+            sess (tf.Session): the session this predictor runs in. If None,
+                will use the default session.
         """
-        self.session = sess
         self.return_input = return_input
         self.input_tensors = input_tensors
         self.output_tensors = output_tensors
+        self.sess = sess

     def _do_call(self, dp):
         assert len(dp) == len(self.input_tensors), \
             "{} != {}".format(len(dp), len(self.input_tensors))
         feed = dict(zip(self.input_tensors, dp))
-        output = self.session.run(self.output_tensors, feed_dict=feed)
+        if self.sess is None:
+            sess = tf.get_default_session()
+        else:
+            sess = self.sess
+        output = sess.run(self.output_tensors, feed_dict=feed)
         return output
...
@@ -133,7 +137,7 @@ class OfflinePredictor(OnlinePredictor):
         sess = tf.Session(config=config.session_config)
         config.session_init.init(sess)
         super(OfflinePredictor, self).__init__(
-            sess, input_vars, output_vars, config.return_input)
+            input_vars, output_vars, config.return_input, sess)

 def get_predict_func(config):
...
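Note the breaking signature change: `sess` moves from first positional argument to an optional trailing one, and call sites in this commit are updated to `OnlinePredictor(ins, outs, return_input, sess)`. With `sess=None`, `_do_call` resolves the session at call time, so a predictor can be constructed before any session exists. Illustrative usage (`pred`, `dp`, and `sess` are placeholder names):

```
pred = OnlinePredictor(input_tensors, output_tensors)   # sess=None

with sess.as_default():       # any live tf.Session
    outputs = pred(dp)        # runs via tf.get_default_session()
```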
@@ -9,9 +9,9 @@ from six.moves import queue, range
 import tensorflow as tf

 from ..utils import logger
-from ..utils.concurrency import DIE, StoppableThread
+from ..utils.concurrency import DIE, StoppableThread, ShareSessionThread
 from ..tfutils.modelutils import describe_model
-from .base import OfflinePredictor, AsyncPredictorBase
+from .base import OnlinePredictor, OfflinePredictor, AsyncPredictorBase

 __all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
            'MultiThreadAsyncPredictor']
...
@@ -73,7 +73,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
             self.outqueue.put((tid, self.predictor(dp)))

-class PredictorWorkerThread(StoppableThread):
+class PredictorWorkerThread(StoppableThread, ShareSessionThread):
     def __init__(self, queue, pred_func, id, batch_size=5):
         super(PredictorWorkerThread, self).__init__()
         self.queue = queue
...
@@ -83,25 +83,26 @@ class PredictorWorkerThread(StoppableThread):
         self.id = id

     def run(self):
-        while not self.stopped():
-            batched, futures = self.fetch_batch()
-            try:
-                outputs = self.func(batched)
-            except tf.errors.CancelledError:
-                for f in futures:
-                    f.cancel()
-                logger.warn("In PredictorWorkerThread id={}, call was cancelled.".format(self.id))
-                return
-            # print "Worker {} batched {} Queue {}".format(
-            #     self.id, len(futures), self.queue.qsize())
-            # debug, for speed testing
-            # if not hasattr(self, 'xxx'):
-            #     self.xxx = outputs = self.func(batched)
-            # else:
-            #     outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
-            for idx, f in enumerate(futures):
-                f.set_result([k[idx] for k in outputs])
+        with self.default_sess():
+            while not self.stopped():
+                batched, futures = self.fetch_batch()
+                try:
+                    outputs = self.func(batched)
+                except tf.errors.CancelledError:
+                    for f in futures:
+                        f.cancel()
+                    logger.warn("In PredictorWorkerThread id={}, call was cancelled.".format(self.id))
+                    return
+                # print "Worker {} batched {} Queue {}".format(
+                #     self.id, len(futures), self.queue.qsize())
+                # debug, for speed testing
+                # if not hasattr(self, 'xxx'):
+                #     self.xxx = outputs = self.func(batched)
+                # else:
+                #     outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
+                for idx, f in enumerate(futures):
+                    f.set_result([k[idx] for k in outputs])

     def fetch_batch(self):
         """ Fetch a batch of data without waiting"""
...
@@ -137,9 +138,12 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
             batch_size (int): the maximum of an internal batch.
         """
         assert len(predictors)
+        self._need_default_sess = False
         for k in predictors:
-            # assert isinstance(k, OnlinePredictor), type(k)
-            # TODO use predictors.return_input here
+            assert isinstance(k, OnlinePredictor), type(k)
+            if k.sess is None:
+                self._need_default_sess = True
+            # TODO support predictors.return_input here
             assert not k.return_input
         self.input_queue = queue.Queue(maxsize=len(predictors) * 100)
         self.threads = [
...
@@ -153,6 +157,10 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
             options.parse_command_line(['--logging=debug'])

     def start(self):
+        if self._need_default_sess:
+            assert tf.get_default_session() is not None, \
+                "No session is bound to predictors; " \
+                "MultiThreadAsyncPredictor.start() has to be called under a default session!"
         for t in self.threads:
             t.start()
...
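Together with the constructor change, the contract is now explicit: if any wrapped predictor was built with `sess=None`, `start()` must be called while a default session is active, otherwise the new assertion fires before any worker thread is spawned. Sketch (the session name is illustrative):

```
async_pred = MultiThreadAsyncPredictor(predictors, batch_size=15)

# async_pred.start()        # would assert: no default session set
with sess.as_default():     # sess: an existing tf.Session (assumed)
    async_pred.start()      # each PredictorWorkerThread captures sess
```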
@@ -41,7 +41,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
                 ['{}{}/'.format(PREDICT_TOWER, k) + n
                  for n in config.output_names])
             self.predictors.append(OnlinePredictor(
-                self.sess, input_vars, output_vars, config.return_input))
+                input_vars, output_vars, config.return_input, self.sess))

     def _do_call(self, dp):
         # use the first tower for compatible PredictorBase interface
...
@@ -90,4 +90,4 @@ class DataParallelOfflinePredictor(OnlinePredictor):
         input_vars = get_tensors_by_names(input_var_names)
         config.session_init.init(sess)
         super(DataParallelOfflinePredictor, self).__init__(
-            sess, input_vars, output_vars, config.return_input)
+            input_vars, output_vars, config.return_input, sess)
...
@@ -35,7 +35,6 @@ class Trainer(object):
         config (TrainConfig): the config used in this trainer.
         model (ModelDesc)
         sess (tf.Session): the current session in use.
-        coord (tf.train.Coordinator)
         stat_holder (StatHolder)
         summary_writer (tf.summary.FileWriter)
...
@@ -53,10 +52,8 @@ class Trainer(object):
         assert isinstance(config, TrainConfig), type(config)
         self.config = config
         self.model = config.model
-        self.sess = tf.Session(config=self.config.session_config)
-        self.coord = tf.train.Coordinator()
-        self.epoch_num = self.config.starting_epoch
+        self.epoch_num = self.config.starting_epoch - 1
         self.local_step = 0

     def train(self):
...
@@ -131,24 +128,29 @@ class Trainer(object):
         describe_model()

         # some final operations that might modify the graph
-        logger.info("Setup callbacks ...")
+        logger.info("Setup callbacks graph ...")
         self.config.callbacks.setup_graph(weakref.proxy(self))
         self._extra_fetches = self.config.callbacks.extra_fetches()

-        self.summary_writer = tf.summary.FileWriter(logger.LOG_DIR, graph=self.sess.graph)
-        self.summary_op = tf.summary.merge_all()
+        logger.info("Setup summaries ...")
+        self.summary_writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph())
+        self.summary_op = tf.summary.merge_all()  # XXX not good
         # create an empty StatHolder
         self.stat_holder = StatHolder(logger.LOG_DIR)

-        logger.info("Initializing graph variables ...")
-        initop = tf.global_variables_initializer()
-        self.sess.run(initop)
+        def after_init(_, __):
+            logger.info("Graph variables initialized.")
+        scaffold = tf.train.Scaffold(
+            init_op=tf.global_variables_initializer(),
+            init_fn=after_init)
+        logger.info("Finalize the graph, create the session ...")
+        self.monitored_sess = tf.train.MonitoredSession(
+            session_creator=tf.train.ChiefSessionCreator(
+                scaffold=scaffold, config=self.config.session_config),
+            hooks=None)
+        self.sess = self.monitored_sess._tf_sess()
         self.config.session_init.init(self.sess)
-        tf.get_default_graph().finalize()
-        tf.train.start_queue_runners(
-            sess=self.sess, coord=self.coord, daemon=True, start=True)

     @abstractmethod
     def _setup(self):
         """ setup Trainer-specific stuff for training"""
...
@@ -176,7 +178,7 @@ class Trainer(object):
             logger.info("Start Epoch {} ...".format(self.epoch_num))
             start_time = time.time()
             for self.local_step in range(self.config.steps_per_epoch):
-                if self.coord.should_stop():
+                if self.monitored_sess.should_stop():
                     return
                 fetch_data = self.run_step()  # implemented by subclass
                 if fetch_data is None:
...
@@ -197,9 +199,8 @@ class Trainer(object):
             raise
         finally:
             callbacks.after_train()
-            self.coord.request_stop()
             self.summary_writer.close()
-            self.sess.close()
+            self.monitored_sess.close()

     def get_predict_func(self, input_names, output_names):
         """
...
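The net effect is the boundary named in the commit message: `ChiefSessionCreator` finalizes the default graph when the session is created, so a callback or input pipeline that still adds ops afterwards now fails loudly instead of silently growing the graph. A standalone illustration (not trainer code):

```
import tensorflow as tf

tf.Variable(0, name='v')
scaffold = tf.train.Scaffold(init_op=tf.global_variables_initializer())
sess = tf.train.MonitoredSession(
    session_creator=tf.train.ChiefSessionCreator(scaffold=scaffold))

try:
    tf.Variable(1, name='too_late')   # the graph is already finalized
except RuntimeError as e:
    print(e)   # "Graph is finalized and cannot be modified."
```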
@@ -4,7 +4,6 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

 import tensorflow as tf
-import threading
 from abc import ABCMeta, abstractmethod
 import six
...
@@ -12,6 +11,7 @@ from ..dataflow import DataFlow, RepeatedData
 from ..tfutils.summary import add_moving_summary
 from ..tfutils import get_op_tensor_name
 from ..utils import logger
+from ..utils.concurrency import ShareSessionThread
 from ..callbacks.concurrency import StartProcOrThread

 __all__ = ['InputData', 'FeedfreeInput',
...
@@ -72,8 +72,8 @@ class FeedfreeInput(InputData):
     """

-class EnqueueThread(threading.Thread):
-    def __init__(self, trainer, queue, ds, input_placehdrs):
+class EnqueueThread(ShareSessionThread):
+    def __init__(self, queue, ds, input_placehdrs):
         super(EnqueueThread, self).__init__()
         self.name = 'EnqueueThread'
         self.daemon = True
...
@@ -81,8 +81,6 @@ class EnqueueThread(threading.Thread):
         self.dataflow = ds
         self.queue = queue
-        self.sess = trainer.sess
-        self.coord = trainer.coord
         self.placehdrs = input_placehdrs

         self.op = self.queue.enqueue(self.placehdrs)
...
@@ -92,27 +90,20 @@ class EnqueueThread(threading.Thread):
             self.size_op, tf.float32, name='input_queue_size'))

     def run(self):
-        try:
-            self.dataflow.reset_state()
-            with self.sess.as_default():
-                while True:
-                    for dp in self.dataflow.get_data():
-                        if self.coord.should_stop():
-                            return
-                        feed = dict(zip(self.placehdrs, dp))
-                        # print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
-                        self.op.run(feed_dict=feed)
-        except tf.errors.CancelledError:
-            pass
-        except Exception:
-            logger.exception("Exception in EnqueueThread:")
-        finally:
-            self.coord.request_stop()
-            try:
-                self.sess.run(self.close_op)
-            except RuntimeError:    # session already closed
-                pass
-            logger.info("Enqueue Thread Exited.")
+        with self.default_sess():
+            try:
+                self.dataflow.reset_state()
+                while True:
+                    for dp in self.dataflow.get_data():
+                        feed = dict(zip(self.placehdrs, dp))
+                        # print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
+                        self.op.run(feed_dict=feed)
+            except tf.errors.CancelledError:
+                pass
+            except Exception:
+                logger.exception("Exception in EnqueueThread:")
+            finally:
+                logger.info("EnqueueThread Exited.")

 class QueueInput(FeedfreeInput):
...
@@ -141,8 +132,7 @@ class QueueInput(FeedfreeInput):
         self.queue = tf.FIFOQueue(
             50, [x.dtype for x in self.input_placehdrs],
             name='input_queue')
-        self.thread = EnqueueThread(
-            trainer, self.queue, self.ds, self.input_placehdrs)
+        self.thread = EnqueueThread(self.queue, self.ds, self.input_placehdrs)
         trainer.config.callbacks.append(StartProcOrThread(self.thread))

     def _get_input_tensors(self):
...
@@ -203,8 +193,7 @@ class BatchQueueInput(FeedfreeInput):
         for shp in self.queue.shapes:
             assert shp.is_fully_defined(), shape_err
-        self.thread = EnqueueThread(
-            trainer, self.queue, self.ds, placehdrs_nobatch)
+        self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch)
         trainer.config.callbacks.append(StartProcOrThread(self.thread))

     def _get_input_tensors(self):
...
@@ -199,7 +199,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
         # pretend to average the grads, in order to make async and
         # sync have consistent effective learning rate
         gradproc = ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False)
-        grad_list = [apply_grad_processors(g, [gradproc]) for g in grad_list]
+        grad_list = apply_grad_processors(grad_list, [gradproc])

         # use grad from the first tower for iteration in main thread
         self.train_op = self.config.optimizer.apply_gradients(grad_list[0], name='min_op')
...
@@ -216,7 +216,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
             def f(op=train_op):  # avoid late-binding
                 self.sess.run([op])
-                next(self.async_step_counter)
+                next(self.async_step_counter)  # atomic due to GIL
             th = LoopThread(f)
             th.name = "AsyncLoopThread-{}".format(k)
             th.pause()
...
@@ -18,19 +18,20 @@ __all__ = ['SimpleTrainer', 'MultiPredictorTowerTrainer']

 class PredictorFactory(object):
     """ Make predictors for a trainer"""

-    def __init__(self, sess, model, towers):
+    def __init__(self, model, towers):
         """
         :param towers: list of gpu relative id
         """
-        self.sess = sess
         self.model = model
         self.towers = towers
         self.tower_built = False

     def get_predictor(self, input_names, output_names, tower):
         """
-        :param tower: need the kth tower (not the gpu id)
-        :returns: an online predictor
+        Args:
+            tower: need the kth tower (not the gpu id)
+        Returns:
+            an online predictor (which has to be used under a default session)
         """
         if not self.tower_built:
             self._build_predict_tower()
...
@@ -53,7 +54,7 @@ class PredictorFactory(object):
         output_names = map(get_name_in_tower, output_names)
         output_vars = get_tensors_by_names(output_names)
-        return OnlinePredictor(self.sess, raw_input_vars, output_vars)
+        return OnlinePredictor(raw_input_vars, output_vars)

     def _build_predict_tower(self):
         # build_predict_tower might get called anywhere, but 'PREDICT_TOWER' should be the outermost name scope
...
@@ -76,7 +77,7 @@ class SimpleTrainer(Trainer):
             config (TrainConfig): the training config.
         """
         super(SimpleTrainer, self).__init__(config)
-        self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
+        self._predictor_factory = PredictorFactory(self.model, [0])
         if config.dataflow is None:
             self._input_method = config.data
             assert isinstance(self._input_method, FeedInput), type(self._input_method)
...
@@ -118,7 +119,7 @@ class MultiPredictorTowerTrainer(Trainer):
     def _setup_predictor_factory(self):
         # by default, use the first training gpu for prediction
         self._predictor_factory = PredictorFactory(
-            self.sess, self.model, self.config.predict_tower)
+            self.model, self.config.predict_tower)

     def get_predict_func(self, input_names, output_names, tower=0):
         """
...
@@ -21,7 +21,8 @@ else:
 import subprocess

-__all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
+__all__ = ['StoppableThread', 'LoopThread', 'ShareSessionThread',
+           'ensure_proc_terminate',
            'OrderedResultGatherProc', 'OrderedContainer', 'DIE',
            'mask_sigint', 'start_proc_mask_signal']
...
@@ -97,6 +98,39 @@ class LoopThread(StoppableThread):
         self._lock.release()

+class ShareSessionThread(threading.Thread):
+    """ A wrapper around a thread so that the thread
+        uses the default session at "start()" time.
+    """
+    def __init__(self, th=None):
+        """
+        Args:
+            th (threading.Thread or None):
+        """
+        super(ShareSessionThread, self).__init__()
+        if th is not None:
+            assert isinstance(th, threading.Thread), th
+            self._th = th
+            self.name = th.name
+            self.daemon = th.daemon
+
+    @contextmanager
+    def default_sess(self):
+        with self._sess.as_default():
+            yield
+
+    def start(self):
+        import tensorflow as tf
+        self._sess = tf.get_default_session()
+        super(ShareSessionThread, self).start()
+
+    def run(self):
+        if not self._th:
+            raise NotImplementedError()
+        with self._sess.as_default():
+            self._th.run()
+
 class DIE(object):
     """ A placeholder class indicating end of queue """
     pass
...
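To recap the two usage patterns this class supports, both exercised by the hunks above (a sketch; `sess` and `some_work` are assumed to exist):

```
import threading
import tensorflow as tf

with sess.as_default():
    # Wrapper form: run() proxies the wrapped thread's run() under sess.
    # Note the wrapped body executes in the wrapper's own thread.
    ShareSessionThread(threading.Thread(target=some_work)).start()

    # Mixin form: subclass, override run(), use default_sess() inside.
    class EvalThread(ShareSessionThread):
        def run(self):
            with self.default_sess():
                assert tf.get_default_session() is sess
    EvalThread().start()
```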