Commit 3e97f126 authored by Yuxin Wu's avatar Yuxin Wu

use tf.train.MonitoredSession. enforce a boundary between graph finalize & session create

parent 224025e3
......@@ -62,9 +62,6 @@ Dependencies:
+ Python 2 or 3
+ TensorFlow >= 1.0.0rc0
+ Python bindings for OpenCV
+ (optional) use tcmalloc if running with large data
```
pip install --user -U git+https://github.com/ppwwyyxx/tensorpack.git
pip install --user -r opt-requirements.txt # (some optional dependencies required by certain submodules, you can install later if prompted)
```
......@@ -23,8 +23,8 @@ from tensorpack.utils.serialize import loads, dumps
from tensorpack.utils.concurrency import LoopThread, ensure_proc_terminate
__all__ = ['SimulatorProcess', 'SimulatorMaster',
'SimulatorProcessStateExchange', 'SimulatorProcessSharedWeight',
'TransitionExperience', 'WeightSync']
'SimulatorProcessStateExchange',
'TransitionExperience']
class TransitionExperience(object):
......
......@@ -38,7 +38,7 @@ CHANNEL = FRAME_HISTORY * 3
IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
LOCAL_TIME_MAX = 5
STEP_PER_EPOCH = 6000
STEPS_PER_EPOCH = 6000
EVAL_EPISODE = 50
BATCH_SIZE = 128
SIMULATOR_PROC = 50
......@@ -150,11 +150,12 @@ class MySimulatorMaster(SimulatorMaster, Callback):
self.queue = queue.Queue(maxsize=BATCH_SIZE * 8 * 2)
def _setup_graph(self):
self.sess = self.trainer.sess
self.async_predictor = MultiThreadAsyncPredictor(
self.trainer.get_predict_funcs(['state'], ['logitsT', 'pred_value'],
PREDICTOR_THREAD), batch_size=15)
self.async_predictor.run()
def _before_train(self):
self.async_predictor.start()
def _on_state(self, state, ident):
def cb(outputs):
......@@ -222,7 +223,7 @@ def get_config():
],
session_config=get_default_sess_config(0.5),
model=M,
steps_per_epoch=STEP_PER_EPOCH,
steps_per_epoch=STEPS_PER_EPOCH,
max_epoch=1000,
)
......
......@@ -40,7 +40,7 @@ def play_model(cfg):
def eval_with_funcs(predict_funcs, nr_eval):
class Worker(StoppableThread):
class Worker(StoppableThread, ShareSessionThread):
def __init__(self, func, queue):
super(Worker, self).__init__()
self._func = func
......@@ -52,6 +52,7 @@ def eval_with_funcs(predict_funcs, nr_eval):
return self._func(*args, **kwargs)
def run(self):
with self.default_sess():
player = get_player(train=False)
while not self.stopped():
try:
......
......@@ -11,7 +11,7 @@ from six.moves import queue
from tensorpack.dataflow import DataFlow
from tensorpack.utils import logger, get_tqdm, get_rng
from tensorpack.utils.concurrency import LoopThread
from tensorpack.utils.concurrency import LoopThread, ShareSessionThread
from tensorpack.callbacks.base import Callback
__all__ = ['ExpReplay']
......@@ -75,10 +75,9 @@ class ExpReplay(DataFlow, Callback):
# spawn a separate thread to run policy, can speed up 1.3x
def populate_job_func():
self._populate_job_queue.get()
with self.trainer.sess.as_default():
for _ in range(self.update_frequency):
self._populate_exp()
th = LoopThread(populate_job_func, pausable=False)
th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
th.name = "SimulatorThread"
return th
......
......@@ -21,9 +21,9 @@ This implementation uses the variants proposed in:
Identity Mappings in Deep Residual Networks, arxiv:1603.05027
I can reproduce the results on 2 TitanX for
n=5, about 7.1% val error after 67k steps (8.6 step/s)
n=18, about 5.95% val error after 80k steps (2.6 step/s)
n=30: a 182-layer network, about 5.6% val error after 51k steps (1.55 step/s)
n=5, about 7.1% val error after 67k steps (15 step/s)
n=18, about 5.95% val error after 80k steps (4.2 step/s)
n=30: a 182-layer network, about 5.6% val error after 51k steps (2.5 step/s)
This model uses the whole training set instead of a train-val split.
To train:
......@@ -131,7 +131,7 @@ def get_data(train_or_test):
imgaug.MapImage(lambda x: x - pp_mean)
]
ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, 128, remainder=not isTrain)
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
if isTrain:
ds = PrefetchData(ds, 3, 2)
return ds
......
......@@ -226,18 +226,19 @@ class FeedfreeInferenceRunner(Triggerable):
G = tf.get_default_graph()
self._output_tensors = [G.get_tensor_by_name(
self._tower_prefix + '/' + n) for n in all_names]
self._sess = self.trainer.sess
# list of list of id
self.inf_to_idxs = dispatcer.get_idx_for_each_entry()
def _trigger(self):
sess = tf.get_default_session()
for inf in self.infs:
inf.before_inference()
with get_tqdm(total=self._size) as pbar:
for _ in range(self._size):
outputs = self._sess.run(fetches=self._output_tensors)
outputs = sess.run(fetches=self._output_tensors)
for inf, idlist in zip(self.infs, self.inf_to_idxs):
inf_output = [outputs[k] for k in idlist]
inf.datapoint(inf_output)
......
......@@ -6,4 +6,4 @@ import cv2 # noqa
import os
os.environ['OPENCV_OPENCL_RUNTIME'] = ''
__version__ = '0.1.5'
__version__ = '0.1.6'
......@@ -23,7 +23,6 @@ class PredictorBase(object):
Base class for all predictors.
Attributes:
session (tf.Session):
return_input (bool): whether the call will also return (inputs, outputs)
or just outpus
"""
......@@ -91,25 +90,30 @@ class AsyncPredictorBase(PredictorBase):
class OnlinePredictor(PredictorBase):
""" A predictor which directly use an existing session. """
def __init__(self, sess, input_tensors, output_tensors, return_input=False):
def __init__(self, input_tensors, output_tensors,
return_input=False, sess=None):
"""
Args:
sess (tf.Session): an existing session.
input_tensors (list): list of names.
output_tensors (list): list of names.
return_input (bool): same as :attr:`PredictorBase.return_input`.
sess (tf.Session): the session this predictor runs in. If None,
will use the default session.
"""
self.session = sess
self.return_input = return_input
self.input_tensors = input_tensors
self.output_tensors = output_tensors
self.sess = sess
def _do_call(self, dp):
assert len(dp) == len(self.input_tensors), \
"{} != {}".format(len(dp), len(self.input_tensors))
feed = dict(zip(self.input_tensors, dp))
output = self.session.run(self.output_tensors, feed_dict=feed)
if self.sess is None:
sess = tf.get_default_session()
else:
sess = self.sess
output = sess.run(self.output_tensors, feed_dict=feed)
return output
......@@ -133,7 +137,7 @@ class OfflinePredictor(OnlinePredictor):
sess = tf.Session(config=config.session_config)
config.session_init.init(sess)
super(OfflinePredictor, self).__init__(
sess, input_vars, output_vars, config.return_input)
input_vars, output_vars, config.return_input, sess)
def get_predict_func(config):
......
......@@ -9,9 +9,9 @@ from six.moves import queue, range
import tensorflow as tf
from ..utils import logger
from ..utils.concurrency import DIE, StoppableThread
from ..utils.concurrency import DIE, StoppableThread, ShareSessionThread
from ..tfutils.modelutils import describe_model
from .base import OfflinePredictor, AsyncPredictorBase
from .base import OnlinePredictor, OfflinePredictor, AsyncPredictorBase
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
'MultiThreadAsyncPredictor']
......@@ -73,7 +73,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
self.outqueue.put((tid, self.predictor(dp)))
class PredictorWorkerThread(StoppableThread):
class PredictorWorkerThread(StoppableThread, ShareSessionThread):
def __init__(self, queue, pred_func, id, batch_size=5):
super(PredictorWorkerThread, self).__init__()
self.queue = queue
......@@ -83,6 +83,7 @@ class PredictorWorkerThread(StoppableThread):
self.id = id
def run(self):
with self.default_sess():
while not self.stopped():
batched, futures = self.fetch_batch()
try:
......@@ -137,9 +138,12 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
batch_size (int): the maximum of an internal batch.
"""
assert len(predictors)
self._need_default_sess = False
for k in predictors:
# assert isinstance(k, OnlinePredictor), type(k)
# TODO use predictors.return_input here
assert isinstance(k, OnlinePredictor), type(k)
if k.sess is None:
self._need_default_sess = True
# TODO support predictors.return_input here
assert not k.return_input
self.input_queue = queue.Queue(maxsize=len(predictors) * 100)
self.threads = [
......@@ -153,6 +157,10 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
options.parse_command_line(['--logging=debug'])
def start(self):
if self._need_default_sess:
assert tf.get_default_session() is not None, \
"Not session is bind to predictors, " \
"MultiThreadAsyncPredictor.start() has to be called under a default session!"
for t in self.threads:
t.start()
......
......@@ -41,7 +41,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
['{}{}/'.format(PREDICT_TOWER, k) + n
for n in config.output_names])
self.predictors.append(OnlinePredictor(
self.sess, input_vars, output_vars, config.return_input))
input_vars, output_vars, config.return_input, self.sess))
def _do_call(self, dp):
# use the first tower for compatible PredictorBase interface
......@@ -90,4 +90,4 @@ class DataParallelOfflinePredictor(OnlinePredictor):
input_vars = get_tensors_by_names(input_var_names)
config.session_init.init(sess)
super(DataParallelOfflinePredictor, self).__init__(
sess, input_vars, output_vars, config.return_input)
input_vars, output_vars, config.return_input, sess)
......@@ -35,7 +35,6 @@ class Trainer(object):
config (TrainConfig): the config used in this trainer.
model (ModelDesc)
sess (tf.Session): the current session in use.
coord (tf.train.Coordinator)
stat_holder (StatHolder)
summary_writer (tf.summary.FileWriter)
......@@ -53,10 +52,8 @@ class Trainer(object):
assert isinstance(config, TrainConfig), type(config)
self.config = config
self.model = config.model
self.sess = tf.Session(config=self.config.session_config)
self.coord = tf.train.Coordinator()
self.epoch_num = self.config.starting_epoch
self.epoch_num = self.config.starting_epoch - 1
self.local_step = 0
def train(self):
......@@ -131,24 +128,29 @@ class Trainer(object):
describe_model()
# some final operations that might modify the graph
logger.info("Setup callbacks ...")
logger.info("Setup callbacks graph ...")
self.config.callbacks.setup_graph(weakref.proxy(self))
self._extra_fetches = self.config.callbacks.extra_fetches()
self.summary_writer = tf.summary.FileWriter(logger.LOG_DIR, graph=self.sess.graph)
self.summary_op = tf.summary.merge_all()
logger.info("Setup summaries ...")
self.summary_writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph())
self.summary_op = tf.summary.merge_all() # XXX not good
# create an empty StatHolder
self.stat_holder = StatHolder(logger.LOG_DIR)
logger.info("Initializing graph variables ...")
initop = tf.global_variables_initializer()
self.sess.run(initop)
def after_init(_, __):
logger.info("Graph variables initialized.")
scaffold = tf.train.Scaffold(
init_op=tf.global_variables_initializer(),
init_fn=after_init)
logger.info("Finalize the graph, create the session ...")
self.monitored_sess = tf.train.MonitoredSession(
session_creator=tf.train.ChiefSessionCreator(
scaffold=scaffold, config=self.config.session_config),
hooks=None)
self.sess = self.monitored_sess._tf_sess()
self.config.session_init.init(self.sess)
tf.get_default_graph().finalize()
tf.train.start_queue_runners(
sess=self.sess, coord=self.coord, daemon=True, start=True)
@abstractmethod
def _setup(self):
""" setup Trainer-specific stuff for training"""
......@@ -176,7 +178,7 @@ class Trainer(object):
logger.info("Start Epoch {} ...".format(self.epoch_num))
start_time = time.time()
for self.local_step in range(self.config.steps_per_epoch):
if self.coord.should_stop():
if self.monitored_sess.should_stop():
return
fetch_data = self.run_step() # implemented by subclass
if fetch_data is None:
......@@ -197,9 +199,8 @@ class Trainer(object):
raise
finally:
callbacks.after_train()
self.coord.request_stop()
self.summary_writer.close()
self.sess.close()
self.monitored_sess.close()
def get_predict_func(self, input_names, output_names):
"""
......
......@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
import threading
from abc import ABCMeta, abstractmethod
import six
......@@ -12,6 +11,7 @@ from ..dataflow import DataFlow, RepeatedData
from ..tfutils.summary import add_moving_summary
from ..tfutils import get_op_tensor_name
from ..utils import logger
from ..utils.concurrency import ShareSessionThread
from ..callbacks.concurrency import StartProcOrThread
__all__ = ['InputData', 'FeedfreeInput',
......@@ -72,8 +72,8 @@ class FeedfreeInput(InputData):
"""
class EnqueueThread(threading.Thread):
def __init__(self, trainer, queue, ds, input_placehdrs):
class EnqueueThread(ShareSessionThread):
def __init__(self, queue, ds, input_placehdrs):
super(EnqueueThread, self).__init__()
self.name = 'EnqueueThread'
self.daemon = True
......@@ -81,8 +81,6 @@ class EnqueueThread(threading.Thread):
self.dataflow = ds
self.queue = queue
self.sess = trainer.sess
self.coord = trainer.coord
self.placehdrs = input_placehdrs
self.op = self.queue.enqueue(self.placehdrs)
......@@ -92,13 +90,11 @@ class EnqueueThread(threading.Thread):
self.size_op, tf.float32, name='input_queue_size'))
def run(self):
with self.default_sess():
try:
self.dataflow.reset_state()
with self.sess.as_default():
while True:
for dp in self.dataflow.get_data():
if self.coord.should_stop():
return
feed = dict(zip(self.placehdrs, dp))
# print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self.op.run(feed_dict=feed)
......@@ -107,12 +103,7 @@ class EnqueueThread(threading.Thread):
except Exception:
logger.exception("Exception in EnqueueThread:")
finally:
self.coord.request_stop()
try:
self.sess.run(self.close_op)
except RuntimeError: # session already closed
pass
logger.info("Enqueue Thread Exited.")
logger.info("EnqueueThread Exited.")
class QueueInput(FeedfreeInput):
......@@ -141,8 +132,7 @@ class QueueInput(FeedfreeInput):
self.queue = tf.FIFOQueue(
50, [x.dtype for x in self.input_placehdrs],
name='input_queue')
self.thread = EnqueueThread(
trainer, self.queue, self.ds, self.input_placehdrs)
self.thread = EnqueueThread(self.queue, self.ds, self.input_placehdrs)
trainer.config.callbacks.append(StartProcOrThread(self.thread))
def _get_input_tensors(self):
......@@ -203,8 +193,7 @@ class BatchQueueInput(FeedfreeInput):
for shp in self.queue.shapes:
assert shp.is_fully_defined(), shape_err
self.thread = EnqueueThread(
trainer, self.queue, self.ds, placehdrs_nobatch)
self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch)
trainer.config.callbacks.append(StartProcOrThread(self.thread))
def _get_input_tensors(self):
......
......@@ -199,7 +199,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
# pretend to average the grads, in order to make async and
# sync have consistent effective learning rate
gradproc = ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False)
grad_list = [apply_grad_processors(g, [gradproc]) for g in grad_list]
grad_list = apply_grad_processors(grad_list, [gradproc])
# use grad from the first tower for iteration in main thread
self.train_op = self.config.optimizer.apply_gradients(grad_list[0], name='min_op')
......@@ -216,7 +216,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
def f(op=train_op): # avoid late-binding
self.sess.run([op])
next(self.async_step_counter)
next(self.async_step_counter) # atomic due to GIL
th = LoopThread(f)
th.name = "AsyncLoopThread-{}".format(k)
th.pause()
......
......@@ -18,19 +18,20 @@ __all__ = ['SimpleTrainer', 'MultiPredictorTowerTrainer']
class PredictorFactory(object):
""" Make predictors for a trainer"""
def __init__(self, sess, model, towers):
def __init__(self, model, towers):
"""
:param towers: list of gpu relative id
"""
self.sess = sess
self.model = model
self.towers = towers
self.tower_built = False
def get_predictor(self, input_names, output_names, tower):
"""
:param tower: need the kth tower (not the gpu id)
:returns: an online predictor
Args:
tower: need the kth tower (not the gpu id)
Returns:
an online predictor (which has to be used under a default session)
"""
if not self.tower_built:
self._build_predict_tower()
......@@ -53,7 +54,7 @@ class PredictorFactory(object):
output_names = map(get_name_in_tower, output_names)
output_vars = get_tensors_by_names(output_names)
return OnlinePredictor(self.sess, raw_input_vars, output_vars)
return OnlinePredictor(raw_input_vars, output_vars)
def _build_predict_tower(self):
# build_predict_tower might get called anywhere, but 'PREDICT_TOWER' should be the outermost name scope
......@@ -76,7 +77,7 @@ class SimpleTrainer(Trainer):
config (TrainConfig): the training config.
"""
super(SimpleTrainer, self).__init__(config)
self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
self._predictor_factory = PredictorFactory(self.model, [0])
if config.dataflow is None:
self._input_method = config.data
assert isinstance(self._input_method, FeedInput), type(self._input_method)
......@@ -118,7 +119,7 @@ class MultiPredictorTowerTrainer(Trainer):
def _setup_predictor_factory(self):
# by default, use the first training gpu for prediction
self._predictor_factory = PredictorFactory(
self.sess, self.model, self.config.predict_tower)
self.model, self.config.predict_tower)
def get_predict_func(self, input_names, output_names, tower=0):
"""
......
......@@ -21,7 +21,8 @@ else:
import subprocess
__all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
__all__ = ['StoppableThread', 'LoopThread', 'ShareSessionThread',
'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE',
'mask_sigint', 'start_proc_mask_signal']
......@@ -97,6 +98,39 @@ class LoopThread(StoppableThread):
self._lock.release()
class ShareSessionThread(threading.Thread):
""" A wrapper around thread so that the thread
uses the default session at "start()" time.
"""
def __init__(self, th=None):
"""
Args:
th (threading.Thread or None):
"""
super(ShareSessionThread, self).__init__()
if th is not None:
assert isinstance(th, threading.Thread), th
self._th = th
self.name = th.name
self.daemon = th.daemon
@contextmanager
def default_sess(self):
with self._sess.as_default():
yield
def start(self):
import tensorflow as tf
self._sess = tf.get_default_session()
super(ShareSessionThread, self).start()
def run(self):
if not self._th:
raise NotImplementedError()
with self._sess.as_default():
self._th.run()
class DIE(object):
""" A placeholder class indicating end of queue """
pass
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment