Commit 6c68f8aa authored by Yuxin Wu's avatar Yuxin Wu

rename some variables in inferencerunner.

parent a47c9980
...@@ -8,13 +8,12 @@ from collections import namedtuple ...@@ -8,13 +8,12 @@ from collections import namedtuple
import six import six
from six.moves import zip, range from six.moves import zip, range
from ..utils import logger, get_tqdm
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..utils import logger, get_tqdm, SUMMARY_BACKUP_KEYS
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..tfutils.collection import freeze_collection
from ..tfutils import TowerContext from ..tfutils import TowerContext
from ..train.input_data import FeedfreeInput from ..train.input_data import FeedfreeInput
from ..predict import build_prediction_graph from ..predict import PredictorTowerBuilder
from .base import Triggerable from .base import Triggerable
from .inference import Inferencer from .inference import Inferencer
...@@ -22,7 +21,7 @@ from .inference import Inferencer ...@@ -22,7 +21,7 @@ from .inference import Inferencer
__all__ = ['InferenceRunner', 'FeedfreeInferenceRunner'] __all__ = ['InferenceRunner', 'FeedfreeInferenceRunner']
class OutputTensorDispatcer(object): class OutputTensorDispatcher(object):
def __init__(self): def __init__(self):
self._names = [] self._names = []
self._idxs = [] self._idxs = []
...@@ -71,7 +70,7 @@ class InferenceRunner(Triggerable): ...@@ -71,7 +70,7 @@ class InferenceRunner(Triggerable):
_IOTensor = namedtuple('IOTensor', ['index', 'isOutput']) _IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
def __init__(self, ds, infs, input_tensors=None): def __init__(self, ds, infs, input_tensor_names=None):
""" """
Args: Args:
ds (DataFlow): the DataFlow to run inferencer on. ds (DataFlow): the DataFlow to run inferencer on.
...@@ -87,16 +86,16 @@ class InferenceRunner(Triggerable): ...@@ -87,16 +86,16 @@ class InferenceRunner(Triggerable):
self.infs = infs self.infs = infs
for v in self.infs: for v in self.infs:
assert isinstance(v, Inferencer), v assert isinstance(v, Inferencer), v
self.input_tensors = input_tensors # names actually self.input_names = input_tensor_names # names actually
def _setup_graph(self): def _setup_graph(self):
self._find_input_tensors() # these are all tensor names self._find_input_tensors() # these are all tensor names
self._find_output_tensors() # may be either tensor name or op name self._find_output_tensors() # may be either tensor name or op name
self.predictor = self.trainer.get_predictor( self.predictor = self.trainer.get_predictor(
self.input_tensors, self.output_tensors) self.input_names, self.output_names)
def _find_input_tensors(self): def _find_input_tensors(self):
if self.input_tensors is None: if self.input_names is None:
input_vars = self.trainer.model.get_reused_placehdrs() input_vars = self.trainer.model.get_reused_placehdrs()
# TODO even if it works here, sparse still is unavailable # TODO even if it works here, sparse still is unavailable
# because get_tensor_by_name doesn't work for sparse # because get_tensor_by_name doesn't work for sparse
...@@ -105,27 +104,27 @@ class InferenceRunner(Triggerable): ...@@ -105,27 +104,27 @@ class InferenceRunner(Triggerable):
if isinstance(x, tf.SparseTensor): if isinstance(x, tf.SparseTensor):
return x.op.name.split('/')[0] return x.op.name.split('/')[0]
return x.name return x.name
self.input_tensors = [get_name(x) for x in input_vars] self.input_names = [get_name(x) for x in input_vars]
def _find_output_tensors(self): def _find_output_tensors(self):
dispatcer = OutputTensorDispatcer() dispatcher = OutputTensorDispatcher()
for inf in self.infs: for inf in self.infs:
dispatcer.add_entry(inf.get_output_tensors()) dispatcher.add_entry(inf.get_output_tensors())
all_names = dispatcer.get_all_names() all_names = dispatcher.get_all_names()
IOTensor = InferenceRunner._IOTensor IOTensor = InferenceRunner._IOTensor
self.output_tensors = list(filter( self.output_names = list(filter(
lambda x: x not in self.input_tensors, all_names)) lambda x: x not in self.input_names, all_names))
def find_tensors(names): def find_tensors(names):
ret = [] ret = []
for name in names: for name in names:
if name in self.input_tensors: if name in self.input_names:
ret.append(IOTensor(self.input_tensors.index(name), False)) ret.append(IOTensor(self.input_names.index(name), False))
else: else:
ret.append(IOTensor(self.output_tensors.index(name), True)) ret.append(IOTensor(self.output_names.index(name), True))
return ret return ret
self.inf_to_tensors = [find_tensors(t) for t in dispatcer.get_names_for_each_entry()] self.inf_to_tensors = [find_tensors(t) for t in dispatcher.get_names_for_each_entry()]
# list of list of IOTensor # list of list of IOTensor
def _trigger(self): def _trigger(self):
...@@ -183,14 +182,11 @@ class FeedfreeInferenceRunner(Triggerable): ...@@ -183,14 +182,11 @@ class FeedfreeInferenceRunner(Triggerable):
def _setup_graph(self): def _setup_graph(self):
self._find_input_tensors() # tensors self._find_input_tensors() # tensors
# TODO reuse predictor code # TODO can we reuse predictor factory?
# overwrite the FeedfreeInferenceRunner name scope with tf.variable_scope(tf.get_variable_scope(), reuse=True):
with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
tf.name_scope(None), \
freeze_collection(SUMMARY_BACKUP_KEYS):
def fn(_): def fn(_):
self.trainer.model.build_graph(self._input_tensors) self.trainer.model.build_graph(self._input_tensors)
build_prediction_graph(fn, [0], prefix=self._prefix) PredictorTowerBuilder(fn, self._prefix).build(0)
self._tower_prefix = TowerContext.get_predict_tower_name(0, self._prefix) self._tower_prefix = TowerContext.get_predict_tower_name(0, self._prefix)
self._find_output_tensors() self._find_output_tensors()
...@@ -220,16 +216,16 @@ class FeedfreeInferenceRunner(Triggerable): ...@@ -220,16 +216,16 @@ class FeedfreeInferenceRunner(Triggerable):
def _find_output_tensors(self): def _find_output_tensors(self):
# TODO doesn't support output an input tensor # TODO doesn't support output an input tensor
dispatcer = OutputTensorDispatcer() dispatcher = OutputTensorDispatcher()
for inf in self.infs: for inf in self.infs:
dispatcer.add_entry(inf.get_output_tensors()) dispatcher.add_entry(inf.get_output_tensors())
all_names = dispatcer.get_all_names() all_names = dispatcher.get_all_names()
G = tf.get_default_graph() G = tf.get_default_graph()
self._output_tensors = [G.get_tensor_by_name( self._output_tensors = [G.get_tensor_by_name(
self._tower_prefix + '/' + n) for n in all_names] self._tower_prefix + '/' + n) for n in all_names]
# list of list of id # list of list of id
self.inf_to_idxs = dispatcer.get_idx_for_each_entry() self.inf_to_idxs = dispatcher.get_idx_for_each_entry()
def _trigger(self): def _trigger(self):
sess = tf.get_default_session() sess = tf.get_default_session()
......
...@@ -129,8 +129,10 @@ class Trainer(object): ...@@ -129,8 +129,10 @@ class Trainer(object):
self.config.callbacks.setup_graph(weakref.proxy(self)) self.config.callbacks.setup_graph(weakref.proxy(self))
self.config.session_init._setup_graph() self.config.session_init._setup_graph()
def after_init(_, __): def after_init(scaffold, sess):
logger.info("Graph variables initialized.") logger.info("Graph variables initialized.")
self.config.session_init._run_init(sess)
scaffold = tf.train.Scaffold( scaffold = tf.train.Scaffold(
init_op=tf.global_variables_initializer(), init_op=tf.global_variables_initializer(),
init_fn=after_init) init_fn=after_init)
...@@ -140,9 +142,7 @@ class Trainer(object): ...@@ -140,9 +142,7 @@ class Trainer(object):
scaffold=scaffold, config=self.config.session_config), scaffold=scaffold, config=self.config.session_config),
hooks=self.config.callbacks.get_hooks()) hooks=self.config.callbacks.get_hooks())
self.hooked_sess = self.monitored_sess # just create an alias self.hooked_sess = self.monitored_sess # just create an alias
self.sess = self.monitored_sess._tf_sess() # expose the underlying session also self.sess = self.monitored_sess._tf_sess() # expose the underlying session also
self.config.session_init._run_init(self.sess)
@abstractmethod @abstractmethod
def _setup(self): def _setup(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment