Commit fe2b4f97 authored by Yuxin Wu

code clean-up in predict/

parent 8df83a93
......@@ -255,7 +255,6 @@ def run_image(model, sess_init, inputs):
pred_config = PredictConfig(
model=model,
session_init=sess_init,
session_config=get_default_sess_config(0.9),
input_names=['input'],
output_names=['output']
)
......
......@@ -125,7 +125,6 @@ def run_image(model, sess_init, inputs):
pred_config = PredictConfig(
model=model,
session_init=sess_init,
session_config=get_default_sess_config(0.9),
input_names=['input'],
output_names=['output']
)
......
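Both example scripts now simply drop the `session_config` argument; the session configuration is handled by `PredictConfig`'s default session creator. A minimal sketch of the updated call (the `'input'`/`'output'` names follow the snippet above):

```python
from tensorpack import PredictConfig

def make_pred_config(model, sess_init):
    # A sketch of the call after this change: no session_config is passed;
    # PredictConfig picks a default session_creator on its own.
    return PredictConfig(
        model=model,
        session_init=sess_init,
        input_names=['input'],
        output_names=['output'],
    )
```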
......@@ -5,14 +5,14 @@
import tensorflow as tf
from abc import ABCMeta, abstractmethod
import six
from ..tfutils.common import get_op_or_tensor_by_name, get_global_step_value
from ..tfutils.common import get_op_or_tensor_by_name
__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory', 'Triggerable']
@six.add_metaclass(ABCMeta)
class Callback(object):
""" Base class for all callbacks
""" Base class for all callbacks.
Attributes:
epoch_num(int): the number of the current epoch.
......@@ -50,7 +50,6 @@ class Callback(object):
pass
def before_train(self):
self._starting_step = get_global_step_value()
self._before_train()
def _before_train(self):
......
......@@ -154,6 +154,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
yield dp
except zmq.ContextTerminated:
logger.info("ContextTerminated in Master Prefetch Process")
return
except:
raise
......
......@@ -88,7 +88,8 @@ class AsyncPredictorBase(PredictorBase):
class OnlinePredictor(PredictorBase):
""" A predictor which directly use an existing session. """
""" A predictor which directly use an existing session and given tensors.
"""
def __init__(self, input_tensors, output_tensors,
return_input=False, sess=None):
......@@ -131,13 +132,13 @@ class OfflinePredictor(OnlinePredictor):
with TowerContext('', False):
config.model.build_graph(input_placehdrs)
input_vars = get_tensors_by_names(config.input_names)
output_vars = get_tensors_by_names(config.output_names)
input_tensors = get_tensors_by_names(config.input_names)
output_tensors = get_tensors_by_names(config.output_names)
sess = tf.Session(config=config.session_config)
sess = config.session_creator.create_session()
config.session_init.init(sess)
super(OfflinePredictor, self).__init__(
input_vars, output_vars, config.return_input, sess)
input_tensors, output_tensors, config.return_input, sess)
def get_predict_func(config):
......@@ -149,6 +150,8 @@ def get_predict_func(config):
def build_prediction_graph(build_tower_fn, towers=[0], prefix=''):
"""
Build graph on each tower.
Args:
build_tower_fn: a function that will be called inside each tower,
taking tower id as the argument.
......
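From the caller's point of view, building a predictor from a `PredictConfig` is unchanged; a minimal sketch for reference (the tensor names are placeholders, and `model` stands for any `ModelDesc`):

```python
from tensorpack import PredictConfig, OfflinePredictor

def run_offline(model, input_array):
    # `model` is assumed to expose an 'input' placeholder and an 'output'
    # tensor; both names are placeholders for this sketch.
    pred_config = PredictConfig(
        model=model,
        input_names=['input'],
        output_names=['output'],
    )
    predictor = OfflinePredictor(pred_config)  # builds its own graph and session
    return predictor(input_array)              # list with one array per output name
```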
......@@ -8,7 +8,7 @@ import six
from six.moves import queue, range
import tensorflow as tf
from ..utils import logger
from ..utils import logger, deprecated
from ..utils.concurrency import DIE, StoppableThread, ShareSessionThread
from ..tfutils.modelutils import describe_model
from .base import OnlinePredictor, OfflinePredictor, AsyncPredictorBase
......@@ -27,6 +27,7 @@ class MultiProcessPredictWorker(multiprocessing.Process):
config (PredictConfig): the config to use.
"""
super(MultiProcessPredictWorker, self).__init__()
self.name = "MultiProcessPredictWorker-{}".format(idx)
self.idx = idx
self.config = config
......@@ -76,6 +77,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
class PredictorWorkerThread(StoppableThread, ShareSessionThread):
def __init__(self, queue, pred_func, id, batch_size=5):
super(PredictorWorkerThread, self).__init__()
self.name = "PredictorWorkerThread-{}".format(id)
self.queue = queue
self.func = pred_func
self.daemon = True
......@@ -112,22 +114,20 @@ class PredictorWorkerThread(StoppableThread, ShareSessionThread):
for k in range(nr_input_var):
batched[k].append(inp[k])
futures.append(f)
cnt = 1
while cnt < self.batch_size:
while len(futures) < self.batch_size:
try:
inp, f = self.queue.get_nowait()
for k in range(nr_input_var):
batched[k].append(inp[k])
futures.append(f)
except queue.Empty:
break
cnt += 1
break # do not wait
return batched, futures
class MultiThreadAsyncPredictor(AsyncPredictorBase):
"""
An multithread online async predictor which runs a list of PredictorBase.
A multithread online async predictor which runs a list of OnlinePredictor.
It would do an extra batching internally.
"""
......@@ -164,6 +164,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
for t in self.threads:
t.start()
@deprecated("Use 'start()' instead!", "2017-03-11")
def run(self): # temporarily for back-compatibility
self.start()
......
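The rewritten fetch loop above drains whatever is already queued after the first blocking `get`, instead of counting requests; a standalone sketch of the same non-blocking pattern (hypothetical helper, not the tensorpack code itself):

```python
from six.moves import queue

def drain_batch(q, first_item, batch_size=5):
    """Collect at most batch_size items; never block after the first one."""
    batch = [first_item]
    while len(batch) < batch_size:
        try:
            batch.append(q.get_nowait())
        except queue.Empty:
            break  # do not wait for more requests
    return batch
```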
......@@ -5,50 +5,60 @@
import six
from ..models import ModelDesc
from ..utils import log_deprecated
from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession
from ..tfutils.sesscreate import NewSession
__all__ = ['PredictConfig']
class PredictConfig(object):
def __init__(self, model, session_init=None,
session_config=get_default_sess_config(0.4),
def __init__(self, model,
session_creator=None,
session_init=None,
session_config=None,
input_names=None,
output_names=None,
return_input=False):
"""
Args:
model (ModelDesc): the model to use.
session_creator (tf.train.SessionCreator): how to create the
session. Defaults to :class:`sesscreate.NewSession()`.
session_init (SessionInit): how to initialize variables of the session.
Defaults to do nothing.
session_config (tf.ConfigProto): deprecated. Use ``session_creator`` instead.
input_names (list): a list of input tensor names. Defaults to all
inputs of the model.
output_names (list): a list of names of the output tensors to predict; they
can be any computable tensor in the graph.
return_input: same as in :attr:`PredictorBase.return_input`.
"""
# TODO use the name "tensor" instead of "variable"
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
self.model = model
assert_type(self.model, ModelDesc)
# XXX does it work? start with minimal memory, but allow growth.
# allow_growth doesn't seem to work very well in TF.
self.session_config = session_config
if session_init is None:
session_init = JustCurrentSession()
self.session_init = session_init
assert_type(self.session_init, SessionInit)
if session_creator is None:
if session_config is not None:
log_deprecated("PredictConfig(session_config=)", "Use session_creator instead!", "2017-04-20")
self.session_creator = NewSession(config=session_config)
else:
self.session_creator = NewSession(config=get_default_sess_config(0.4))
else:
self.session_creator = session_creator
# inputs & outputs
self.input_names = input_names
if self.input_names is None:
# neither option is set, assume all inputs
raw_vars = self.model.get_inputs_desc()
self.input_names = [k.name for k in raw_vars]
raw_tensors = self.model.get_inputs_desc()
self.input_names = [k.name for k in raw_tensors]
self.output_names = output_names
assert_type(self.output_names, list)
assert_type(self.input_names, list)
......
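For code that previously tuned GPU memory through `session_config`, the deprecation branch above implies the equivalent spelling with the new interface; a hedged sketch (the input/output names are placeholders):

```python
from tensorpack import PredictConfig
from tensorpack.tfutils import get_default_sess_config
from tensorpack.tfutils.sesscreate import NewSession

def make_pred_config(model):
    # Deprecated form (accepted with a warning until 2017-04-20):
    #   PredictConfig(model=model, session_config=get_default_sess_config(0.4), ...)
    # New form, routing the same ConfigProto through a session_creator:
    return PredictConfig(
        model=model,
        session_creator=NewSession(config=get_default_sess_config(0.4)),
        input_names=['input'],
        output_names=['output'],
    )
```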
......@@ -5,8 +5,6 @@
import tensorflow as tf
from ..utils import logger
from ..utils.naming import PREDICT_TOWER
from ..tfutils import get_tensors_by_names, TowerContext
from .base import OnlinePredictor, build_prediction_graph
......@@ -31,17 +29,17 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
config.model.build_graph(config.model.get_reused_placehdrs())
build_prediction_graph(fn, towers)
self.sess = tf.Session(config=config.session_config)
self.sess = config.session_creator.create_session()
config.session_init.init(self.sess)
input_vars = get_tensors_by_names(config.input_names)
input_tensors = get_tensors_by_names(config.input_names)
for k in towers:
output_vars = get_tensors_by_names(
['{}{}/'.format(PREDICT_TOWER, k) + n
output_tensors = get_tensors_by_names(
[TowerContext.get_predict_tower_name('', k) + '/' + n
for n in config.output_names])
self.predictors.append(OnlinePredictor(
input_vars, output_vars, config.return_input, self.sess))
input_tensors, output_tensors, config.return_input, self.sess))
def _do_call(self, dp):
# use the first tower for compatible PredictorBase interface
......@@ -57,7 +55,9 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
class DataParallelOfflinePredictor(OnlinePredictor):
""" A data-parallel predictor.
It runs different towers in parallel.
Its input is: [input[0] in tower[0], input[1] in tower[0], ...,
input[0] in tower[1], input[1] in tower[1], ...]
And same for the output.
"""
def __init__(self, config, towers):
......@@ -68,26 +68,25 @@ class DataParallelOfflinePredictor(OnlinePredictor):
"""
self.graph = tf.Graph()
with self.graph.as_default():
sess = tf.Session(config=config.session_config)
input_var_names = []
output_vars = []
for idx, k in enumerate(towers):
towername = PREDICT_TOWER + str(k)
input_vars = config.model.build_placeholders(
prefix=towername + '-')
logger.info(
"Building graph for predictor tower {}...".format(k))
with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
TowerContext(towername, is_training=False), \
tf.variable_scope(tf.get_variable_scope(),
reuse=True if idx > 0 else None):
config.model.build_graph(input_vars)
input_var_names.extend([k.name for k in input_vars])
output_vars.extend(get_tensors_by_names(
input_names = []
output_tensors = []
def build_tower(k):
towername = TowerContext.get_predict_tower_name(k)
# inputs (placeholders) for this tower only
input_tensors = config.model.build_placeholders(prefix=towername + '/')
config.model.build_graph(input_tensors)
input_names.extend([t.name for t in input_tensors])
output_tensors.extend(get_tensors_by_names(
[towername + '/' + n
for n in config.output_names]))
input_vars = get_tensors_by_names(input_var_names)
build_prediction_graph(build_tower, towers)
input_tensors = get_tensors_by_names(input_names)
sess = config.session_creator.create_session()
config.session_init.init(sess)
super(DataParallelOfflinePredictor, self).__init__(
input_vars, output_vars, config.return_input, sess)
input_tensors, output_tensors, config.return_input, sess)
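The datapoint layout described in the new docstring can be produced by flattening per-tower datapoints tower by tower; a small illustrative helper (hypothetical, not part of the library):

```python
from itertools import chain

def flatten_for_data_parallel(per_tower_datapoints):
    """[[img0, lbl0], [img1, lbl1]] -> [img0, lbl0, img1, lbl1],
    matching the input order DataParallelOfflinePredictor expects."""
    return list(chain.from_iterable(per_tower_datapoints))
```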
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: sesscreate.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
__all__ = ['NewSession', 'ReuseSession']
class NewSession(tf.train.SessionCreator):
def __init__(self, target='', graph=None, config=None):
"""
Args:
target, graph, config: same as :meth:`Session.__init__()`.
"""
self.target = target
self.config = config
self.graph = graph
def create_session(self):
return tf.Session(target=self.target, graph=self.graph, config=self.config)
class ReuseSession(tf.train.SessionCreator):
def __init__(self, sess):
"""
Args:
sess (tf.Session): the session to reuse
"""
self.sess = sess
def create_session(self):
return self.sess
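Both creators are meant as drop-in values for `PredictConfig(session_creator=...)`; a short usage sketch (the ConfigProto option is an arbitrary example):

```python
import tensorflow as tf
from tensorpack.tfutils.sesscreate import NewSession, ReuseSession

# Create a fresh session with a custom ConfigProto:
sess = NewSession(config=tf.ConfigProto(allow_soft_placement=True)).create_session()

# Or hand an existing session to a predictor without creating another one:
assert ReuseSession(sess).create_session() is sess
```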
......@@ -12,13 +12,13 @@ from .common import get_op_tensor_name
from .varmanip import (SessionUpdate, get_savename_from_varname,
is_training_name, get_checkpoint_path)
__all__ = ['SessionInit', 'NewSession', 'SaverRestore', 'SaverRestoreRelaxed',
__all__ = ['SessionInit', 'SaverRestore', 'SaverRestoreRelaxed',
'ParamRestore', 'ChainInit',
'JustCurrentSession', 'get_model_loader']
class SessionInit(object):
""" Base class for utilities to initialize a session. """
""" Base class for utilities to initialize a (existing) session. """
def init(self, sess):
"""
Initialize a session
......@@ -44,17 +44,6 @@ class JustCurrentSession(SessionInit):
pass
class NewSession(SessionInit):
"""
Initialize global variables by their initializer.
"""
def _setup_graph(self):
self.op = tf.global_variables_initializer()
def _run_init(self, sess):
sess.run(self.op)
class CheckpointReaderAdapter(object):
"""
An adapter to work around old checkpoint format, where the keys are op
......@@ -207,15 +196,11 @@ class ChainInit(SessionInit):
to form a composition of models.
"""
def __init__(self, sess_inits, new_session=True):
def __init__(self, sess_inits):
"""
Args:
sess_inits (list[SessionInit]): list of :class:`SessionInit` instances.
new_session (bool): add a ``NewSession()`` at the beginning, if
not already there.
"""
if new_session and not isinstance(sess_inits[0], NewSession):
sess_inits.insert(0, NewSession())
self.inits = sess_inits
def _init(self, sess):
......
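With the `new_session` flag gone, `ChainInit` is purely a composition of initializers applied in order; a hedged sketch (the checkpoint path and the variable name are hypothetical):

```python
import numpy as np
from tensorpack.tfutils.sessinit import ChainInit, SaverRestore, ParamRestore

def make_init(checkpoint_path, head_weights=np.zeros((10, 10))):
    return ChainInit([
        SaverRestore(checkpoint_path),            # restore most variables from a checkpoint
        ParamRestore({'head/W': head_weights}),   # then overwrite one variable by name
    ])
```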
......@@ -14,7 +14,7 @@ __all__ = ['PredictorFactory']
class PredictorFactory(object):
""" Make predictors for a trainer"""
""" Make predictors from a trainer."""
def __init__(self, trainer):
"""
......@@ -25,6 +25,7 @@ class PredictorFactory(object):
self.towers = trainer.config.predict_tower
assert isinstance(self.towers, list)
# TODO sess option
def get_predictor(self, input_names, output_names, tower):
"""
Args:
......@@ -48,11 +49,11 @@ class PredictorFactory(object):
return get_name_in_tower(name)
input_names = map(maybe_inside_tower, input_names)
raw_input_vars = get_tensors_by_names(input_names)
raw_input_tensors = get_tensors_by_names(input_names)
output_names = map(get_name_in_tower, output_names)
output_vars = get_tensors_by_names(output_names)
return OnlinePredictor(raw_input_vars, output_vars)
output_tensors = get_tensors_by_names(output_names)
return OnlinePredictor(raw_input_tensors, output_tensors)
@memoized
def _build_predict_tower(self):
......
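Stripped of the tower-name rewriting, `get_predictor` boils down to looking the tensors up by name and wrapping them in an `OnlinePredictor`; a simplified sketch under that assumption (not the factory's full logic):

```python
from tensorpack.tfutils import get_tensors_by_names
from tensorpack.predict import OnlinePredictor

def predictor_from_graph(input_names, output_names, sess=None):
    # Look the tensors up in the current graph; the real factory additionally
    # maps the names into the prediction tower first.
    input_tensors = get_tensors_by_names(input_names)
    output_tensors = get_tensors_by_names(output_names)
    return OnlinePredictor(input_tensors, output_tensors, sess=sess)
```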