Commit fe2b4f97 authored by Yuxin Wu

code clean-up in predict/

parent 8df83a93
...@@ -255,7 +255,6 @@ def run_image(model, sess_init, inputs): ...@@ -255,7 +255,6 @@ def run_image(model, sess_init, inputs):
pred_config = PredictConfig( pred_config = PredictConfig(
model=model, model=model,
session_init=sess_init, session_init=sess_init,
session_config=get_default_sess_config(0.9),
input_names=['input'], input_names=['input'],
output_names=['output'] output_names=['output']
) )
......
...@@ -125,7 +125,6 @@ def run_image(model, sess_init, inputs): ...@@ -125,7 +125,6 @@ def run_image(model, sess_init, inputs):
pred_config = PredictConfig( pred_config = PredictConfig(
model=model, model=model,
session_init=sess_init, session_init=sess_init,
session_config=get_default_sess_config(0.9),
input_names=['input'], input_names=['input'],
output_names=['output'] output_names=['output']
) )
......
...@@ -5,14 +5,14 @@ ...@@ -5,14 +5,14 @@
import tensorflow as tf import tensorflow as tf
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import six import six
from ..tfutils.common import get_op_or_tensor_by_name, get_global_step_value from ..tfutils.common import get_op_or_tensor_by_name
__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory', 'Triggerable'] __all__ = ['Callback', 'ProxyCallback', 'CallbackFactory', 'Triggerable']
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class Callback(object): class Callback(object):
""" Base class for all callbacks """ Base class for all callbacks.
Attributes: Attributes:
epoch_num(int): the number of the current epoch. epoch_num(int): the number of the current epoch.
...@@ -50,7 +50,6 @@ class Callback(object): ...@@ -50,7 +50,6 @@ class Callback(object):
pass pass
def before_train(self): def before_train(self):
self._starting_step = get_global_step_value()
self._before_train() self._before_train()
def _before_train(self): def _before_train(self):
......
...@@ -154,6 +154,7 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -154,6 +154,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
yield dp yield dp
except zmq.ContextTerminated: except zmq.ContextTerminated:
logger.info("ContextTerminated in Master Prefetch Process") logger.info("ContextTerminated in Master Prefetch Process")
return
except: except:
raise raise
......
...@@ -88,7 +88,8 @@ class AsyncPredictorBase(PredictorBase): ...@@ -88,7 +88,8 @@ class AsyncPredictorBase(PredictorBase):
class OnlinePredictor(PredictorBase): class OnlinePredictor(PredictorBase):
""" A predictor which directly use an existing session. """ """ A predictor which directly use an existing session and given tensors.
"""
def __init__(self, input_tensors, output_tensors, def __init__(self, input_tensors, output_tensors,
return_input=False, sess=None): return_input=False, sess=None):
...@@ -131,13 +132,13 @@ class OfflinePredictor(OnlinePredictor): ...@@ -131,13 +132,13 @@ class OfflinePredictor(OnlinePredictor):
with TowerContext('', False): with TowerContext('', False):
config.model.build_graph(input_placehdrs) config.model.build_graph(input_placehdrs)
input_vars = get_tensors_by_names(config.input_names) input_tensors = get_tensors_by_names(config.input_names)
output_vars = get_tensors_by_names(config.output_names) output_tensors = get_tensors_by_names(config.output_names)
sess = tf.Session(config=config.session_config) sess = config.session_creator.create_session()
config.session_init.init(sess) config.session_init.init(sess)
super(OfflinePredictor, self).__init__( super(OfflinePredictor, self).__init__(
input_vars, output_vars, config.return_input, sess) input_tensors, output_tensors, config.return_input, sess)
def get_predict_func(config): def get_predict_func(config):
...@@ -149,6 +150,8 @@ def get_predict_func(config): ...@@ -149,6 +150,8 @@ def get_predict_func(config):
def build_prediction_graph(build_tower_fn, towers=[0], prefix=''): def build_prediction_graph(build_tower_fn, towers=[0], prefix=''):
""" """
Build graph on each tower.
Args: Args:
build_tower_fn: a function that will be called inside each tower, build_tower_fn: a function that will be called inside each tower,
taking tower id as the argument. taking tower id as the argument.
......
...@@ -8,7 +8,7 @@ import six ...@@ -8,7 +8,7 @@ import six
from six.moves import queue, range from six.moves import queue, range
import tensorflow as tf import tensorflow as tf
from ..utils import logger from ..utils import logger, deprecated
from ..utils.concurrency import DIE, StoppableThread, ShareSessionThread from ..utils.concurrency import DIE, StoppableThread, ShareSessionThread
from ..tfutils.modelutils import describe_model from ..tfutils.modelutils import describe_model
from .base import OnlinePredictor, OfflinePredictor, AsyncPredictorBase from .base import OnlinePredictor, OfflinePredictor, AsyncPredictorBase
...@@ -27,6 +27,7 @@ class MultiProcessPredictWorker(multiprocessing.Process): ...@@ -27,6 +27,7 @@ class MultiProcessPredictWorker(multiprocessing.Process):
config (PredictConfig): the config to use. config (PredictConfig): the config to use.
""" """
super(MultiProcessPredictWorker, self).__init__() super(MultiProcessPredictWorker, self).__init__()
self.name = "MultiProcessPredictWorker-{}".format(idx)
self.idx = idx self.idx = idx
self.config = config self.config = config
...@@ -76,6 +77,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker): ...@@ -76,6 +77,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
class PredictorWorkerThread(StoppableThread, ShareSessionThread): class PredictorWorkerThread(StoppableThread, ShareSessionThread):
def __init__(self, queue, pred_func, id, batch_size=5): def __init__(self, queue, pred_func, id, batch_size=5):
super(PredictorWorkerThread, self).__init__() super(PredictorWorkerThread, self).__init__()
self.name = "PredictorWorkerThread-{}".format(id)
self.queue = queue self.queue = queue
self.func = pred_func self.func = pred_func
self.daemon = True self.daemon = True
...@@ -112,22 +114,20 @@ class PredictorWorkerThread(StoppableThread, ShareSessionThread): ...@@ -112,22 +114,20 @@ class PredictorWorkerThread(StoppableThread, ShareSessionThread):
for k in range(nr_input_var): for k in range(nr_input_var):
batched[k].append(inp[k]) batched[k].append(inp[k])
futures.append(f) futures.append(f)
cnt = 1 while len(futures) < self.batch_size:
while cnt < self.batch_size:
try: try:
inp, f = self.queue.get_nowait() inp, f = self.queue.get_nowait()
for k in range(nr_input_var): for k in range(nr_input_var):
batched[k].append(inp[k]) batched[k].append(inp[k])
futures.append(f) futures.append(f)
except queue.Empty: except queue.Empty:
break break # do not wait
cnt += 1
return batched, futures return batched, futures
class MultiThreadAsyncPredictor(AsyncPredictorBase): class MultiThreadAsyncPredictor(AsyncPredictorBase):
""" """
An multithread online async predictor which runs a list of PredictorBase. An multithread online async predictor which runs a list of OnlinePredictor.
It would do an extra batching internally. It would do an extra batching internally.
""" """
...@@ -164,6 +164,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase): ...@@ -164,6 +164,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
for t in self.threads: for t in self.threads:
t.start() t.start()
@deprecated("Use 'start()' instead!", "2017-03-11")
def run(self): # temporarily for back-compatibility def run(self): # temporarily for back-compatibility
self.start() self.start()
......
...@@ -5,50 +5,60 @@ ...@@ -5,50 +5,60 @@
import six import six
from ..models import ModelDesc from ..models import ModelDesc
from ..utils import log_deprecated
from ..tfutils import get_default_sess_config from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession from ..tfutils.sessinit import SessionInit, JustCurrentSession
from ..tfutils.sesscreate import NewSession
__all__ = ['PredictConfig'] __all__ = ['PredictConfig']
class PredictConfig(object): class PredictConfig(object):
def __init__(self, model, session_init=None, def __init__(self, model,
session_config=get_default_sess_config(0.4), session_creator=None,
session_init=None,
session_config=None,
input_names=None, input_names=None,
output_names=None, output_names=None,
return_input=False): return_input=False):
""" """
Args: Args:
model (ModelDesc): the model to use. model (ModelDesc): the model to use.
session_creator (tf.train.SessionCreator): how to create the
session. Defaults to :class:`sesscreate.NewSession()`.
session_init (SessionInit): how to initialize variables of the session. session_init (SessionInit): how to initialize variables of the session.
Defaults to do nothing. Defaults to do nothing.
session_config]
input_names (list): a list of input tensor names. Defaults to all input_names (list): a list of input tensor names. Defaults to all
inputs of the model. inputs of the model.
output_names (list): a list of names of the output tensors to predict, the output_names (list): a list of names of the output tensors to predict, the
tensors can be any computable tensor in the graph. tensors can be any computable tensor in the graph.
return_input: same as in :attr:`PredictorBase.return_input`. return_input: same as in :attr:`PredictorBase.return_input`.
""" """
# TODO use the name "tensor" instead of "variable"
def assert_type(v, tp): def assert_type(v, tp):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), v.__class__
self.model = model self.model = model
assert_type(self.model, ModelDesc) assert_type(self.model, ModelDesc)
# XXX does it work? start with minimal memory, but allow growth.
# allow_growth doesn't seem to work very well in TF.
self.session_config = session_config
if session_init is None: if session_init is None:
session_init = JustCurrentSession() session_init = JustCurrentSession()
self.session_init = session_init self.session_init = session_init
assert_type(self.session_init, SessionInit) assert_type(self.session_init, SessionInit)
if session_creator is None:
if session_config is not None:
log_deprecated("PredictConfig(session_config=)", "Use session_creator instead!", "2017-04-20")
self.session_creator = NewSession(config=session_config)
else:
self.session_creator = NewSession(config=get_default_sess_config(0.4))
else:
self.session_creator = session_creator
# inputs & outputs # inputs & outputs
self.input_names = input_names self.input_names = input_names
if self.input_names is None: if self.input_names is None:
# neither options is set, assume all inputs # neither options is set, assume all inputs
raw_vars = self.model.get_inputs_desc() raw_tensors = self.model.get_inputs_desc()
self.input_names = [k.name for k in raw_vars] self.input_names = [k.name for k in raw_tensors]
self.output_names = output_names self.output_names = output_names
assert_type(self.output_names, list) assert_type(self.output_names, list)
assert_type(self.input_names, list) assert_type(self.input_names, list)
......
...@@ -5,8 +5,6 @@ ...@@ -5,8 +5,6 @@
import tensorflow as tf import tensorflow as tf
from ..utils import logger
from ..utils.naming import PREDICT_TOWER
from ..tfutils import get_tensors_by_names, TowerContext from ..tfutils import get_tensors_by_names, TowerContext
from .base import OnlinePredictor, build_prediction_graph from .base import OnlinePredictor, build_prediction_graph
...@@ -31,17 +29,17 @@ class MultiTowerOfflinePredictor(OnlinePredictor): ...@@ -31,17 +29,17 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
config.model.build_graph(config.model.get_reused_placehdrs()) config.model.build_graph(config.model.get_reused_placehdrs())
build_prediction_graph(fn, towers) build_prediction_graph(fn, towers)
self.sess = tf.Session(config=config.session_config) self.sess = config.session_creator.create_session()
config.session_init.init(self.sess) config.session_init.init(self.sess)
input_vars = get_tensors_by_names(config.input_names) input_tensors = get_tensors_by_names(config.input_names)
for k in towers: for k in towers:
output_vars = get_tensors_by_names( output_tensors = get_tensors_by_names(
['{}{}/'.format(PREDICT_TOWER, k) + n [TowerContext.get_predict_towre_name('', k) + '/' + n
for n in config.output_names]) for n in config.output_names])
self.predictors.append(OnlinePredictor( self.predictors.append(OnlinePredictor(
input_vars, output_vars, config.return_input, self.sess)) input_tensors, output_tensors, config.return_input, self.sess))
def _do_call(self, dp): def _do_call(self, dp):
# use the first tower for compatible PredictorBase interface # use the first tower for compatible PredictorBase interface
...@@ -57,7 +55,9 @@ class MultiTowerOfflinePredictor(OnlinePredictor): ...@@ -57,7 +55,9 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
class DataParallelOfflinePredictor(OnlinePredictor): class DataParallelOfflinePredictor(OnlinePredictor):
""" A data-parallel predictor. """ A data-parallel predictor.
It runs different towers in parallel. Its input is: [input[0] in tower[0], input[1] in tower[0], ...,
input[0] in tower[1], input[1] in tower[1], ...]
And same for the output.
""" """
def __init__(self, config, towers): def __init__(self, config, towers):
...@@ -68,26 +68,25 @@ class DataParallelOfflinePredictor(OnlinePredictor): ...@@ -68,26 +68,25 @@ class DataParallelOfflinePredictor(OnlinePredictor):
""" """
self.graph = tf.Graph() self.graph = tf.Graph()
with self.graph.as_default(): with self.graph.as_default():
sess = tf.Session(config=config.session_config) input_names = []
input_var_names = [] output_tensors = []
output_vars = []
for idx, k in enumerate(towers): def build_tower(k):
towername = PREDICT_TOWER + str(k) towername = TowerContext.get_predict_tower_name(k)
input_vars = config.model.build_placeholders( # inputs (placeholders) for this tower only
prefix=towername + '-') input_tensors = config.model.build_placeholders(prefix=towername + '/')
logger.info( config.model.build_graph(input_tensors)
"Building graph for predictor tower {}...".format(k))
with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \ input_names.extend([t.name for t in input_tensors])
TowerContext(towername, is_training=False), \ output_tensors.extend(get_tensors_by_names(
tf.variable_scope(tf.get_variable_scope(),
reuse=True if idx > 0 else None):
config.model.build_graph(input_vars)
input_var_names.extend([k.name for k in input_vars])
output_vars.extend(get_tensors_by_names(
[towername + '/' + n [towername + '/' + n
for n in config.output_names])) for n in config.output_names]))
input_vars = get_tensors_by_names(input_var_names) build_prediction_graph(build_tower, towers)
input_tensors = get_tensors_by_names(input_names)
sess = config.session_creator.create_session()
config.session_init.init(sess) config.session_init.init(sess)
super(DataParallelOfflinePredictor, self).__init__( super(DataParallelOfflinePredictor, self).__init__(
input_vars, output_vars, config.return_input, sess) input_tensors, output_tensors, config.return_input, sess)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: sesscreate.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
__all__ = ['NewSession', 'ReuseSession']
class NewSession(tf.train.SessionCreator):
    """A session creator which builds a fresh :class:`tf.Session` on every call."""

    def __init__(self, target='', graph=None, config=None):
        """
        Args:
            target, graph, config: stored as-is and forwarded unchanged to
                :meth:`Session.__init__()` when a session is created.
        """
        self.target, self.graph, self.config = target, graph, config

    def create_session(self):
        """Return a new ``tf.Session`` constructed from the stored arguments."""
        session_kwargs = {
            'target': self.target,
            'graph': self.graph,
            'config': self.config,
        }
        return tf.Session(**session_kwargs)
class ReuseSession(tf.train.SessionCreator):
    """A session creator which hands back an already-existing session."""

    def __init__(self, sess):
        """
        Args:
            sess (tf.Session): the session to reuse
        """
        self.sess = sess

    def create_session(self):
        """Return the session given at construction; no new session is made."""
        existing_session = self.sess
        return existing_session
...@@ -12,13 +12,13 @@ from .common import get_op_tensor_name ...@@ -12,13 +12,13 @@ from .common import get_op_tensor_name
from .varmanip import (SessionUpdate, get_savename_from_varname, from .varmanip import (SessionUpdate, get_savename_from_varname,
is_training_name, get_checkpoint_path) is_training_name, get_checkpoint_path)
__all__ = ['SessionInit', 'NewSession', 'SaverRestore', 'SaverRestoreRelaxed', __all__ = ['SessionInit', 'SaverRestore', 'SaverRestoreRelaxed',
'ParamRestore', 'ChainInit', 'ParamRestore', 'ChainInit',
'JustCurrentSession', 'get_model_loader'] 'JustCurrentSession', 'get_model_loader']
class SessionInit(object): class SessionInit(object):
""" Base class for utilities to initialize a session. """ """ Base class for utilities to initialize a (existing) session. """
def init(self, sess): def init(self, sess):
""" """
Initialize a session Initialize a session
...@@ -44,17 +44,6 @@ class JustCurrentSession(SessionInit): ...@@ -44,17 +44,6 @@ class JustCurrentSession(SessionInit):
pass pass
class NewSession(SessionInit):
"""
Initialize global variables by their initializer.
"""
def _setup_graph(self):
self.op = tf.global_variables_initializer()
def _run_init(self, sess):
sess.run(self.op)
class CheckpointReaderAdapter(object): class CheckpointReaderAdapter(object):
""" """
An adapter to work around old checkpoint format, where the keys are op An adapter to work around old checkpoint format, where the keys are op
...@@ -207,15 +196,11 @@ class ChainInit(SessionInit): ...@@ -207,15 +196,11 @@ class ChainInit(SessionInit):
to form a composition of models. to form a composition of models.
""" """
def __init__(self, sess_inits, new_session=True): def __init__(self, sess_inits):
""" """
Args: Args:
sess_inits (list[SessionInit]): list of :class:`SessionInit` instances. sess_inits (list[SessionInit]): list of :class:`SessionInit` instances.
new_session (bool): add a ``NewSession()`` and the beginning, if
not there.
""" """
if new_session and not isinstance(sess_inits[0], NewSession):
sess_inits.insert(0, NewSession())
self.inits = sess_inits self.inits = sess_inits
def _init(self, sess): def _init(self, sess):
......
...@@ -14,7 +14,7 @@ __all__ = ['PredictorFactory'] ...@@ -14,7 +14,7 @@ __all__ = ['PredictorFactory']
class PredictorFactory(object): class PredictorFactory(object):
""" Make predictors for a trainer""" """ Make predictors from a trainer."""
def __init__(self, trainer): def __init__(self, trainer):
""" """
...@@ -25,6 +25,7 @@ class PredictorFactory(object): ...@@ -25,6 +25,7 @@ class PredictorFactory(object):
self.towers = trainer.config.predict_tower self.towers = trainer.config.predict_tower
assert isinstance(self.towers, list) assert isinstance(self.towers, list)
# TODO sess option
def get_predictor(self, input_names, output_names, tower): def get_predictor(self, input_names, output_names, tower):
""" """
Args: Args:
...@@ -48,11 +49,11 @@ class PredictorFactory(object): ...@@ -48,11 +49,11 @@ class PredictorFactory(object):
return get_name_in_tower(name) return get_name_in_tower(name)
input_names = map(maybe_inside_tower, input_names) input_names = map(maybe_inside_tower, input_names)
raw_input_vars = get_tensors_by_names(input_names) raw_input_tensors = get_tensors_by_names(input_names)
output_names = map(get_name_in_tower, output_names) output_names = map(get_name_in_tower, output_names)
output_vars = get_tensors_by_names(output_names) output_tensors = get_tensors_by_names(output_names)
return OnlinePredictor(raw_input_vars, output_vars) return OnlinePredictor(raw_input_tensors, output_tensors)
@memoized @memoized
def _build_predict_tower(self): def _build_predict_tower(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment