Commit 6c68f8aa authored by Yuxin Wu

Rename some variables in InferenceRunner.

parent a47c9980
......@@ -8,13 +8,12 @@ from collections import namedtuple
import six
from six.moves import zip, range
from ..utils import logger, get_tqdm
from ..dataflow import DataFlow
from ..utils import logger, get_tqdm, SUMMARY_BACKUP_KEYS
from ..tfutils.common import get_op_tensor_name
from ..tfutils.collection import freeze_collection
from ..tfutils import TowerContext
from ..train.input_data import FeedfreeInput
from ..predict import build_prediction_graph
from ..predict import PredictorTowerBuilder
from .base import Triggerable
from .inference import Inferencer
......@@ -22,7 +21,7 @@ from .inference import Inferencer
__all__ = ['InferenceRunner', 'FeedfreeInferenceRunner']
class OutputTensorDispatcer(object):
class OutputTensorDispatcher(object):
def __init__(self):
self._names = []
self._idxs = []
......@@ -71,7 +70,7 @@ class InferenceRunner(Triggerable):
_IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
def __init__(self, ds, infs, input_tensors=None):
def __init__(self, ds, infs, input_tensor_names=None):
"""
Args:
ds (DataFlow): the DataFlow to run inferencer on.
......@@ -87,16 +86,16 @@ class InferenceRunner(Triggerable):
self.infs = infs
for v in self.infs:
assert isinstance(v, Inferencer), v
self.input_tensors = input_tensors # names actually
self.input_names = input_tensor_names # names actually
def _setup_graph(self):
self._find_input_tensors() # these are all tensor names
self._find_output_tensors() # may be either tensor name or op name
self.predictor = self.trainer.get_predictor(
self.input_tensors, self.output_tensors)
self.input_names, self.output_names)
def _find_input_tensors(self):
if self.input_tensors is None:
if self.input_names is None:
input_vars = self.trainer.model.get_reused_placehdrs()
# TODO even if it works here, sparse still is unavailable
# because get_tensor_by_name doesn't work for sparse
......@@ -105,27 +104,27 @@ class InferenceRunner(Triggerable):
if isinstance(x, tf.SparseTensor):
return x.op.name.split('/')[0]
return x.name
self.input_tensors = [get_name(x) for x in input_vars]
self.input_names = [get_name(x) for x in input_vars]
def _find_output_tensors(self):
dispatcer = OutputTensorDispatcer()
dispatcher = OutputTensorDispatcher()
for inf in self.infs:
dispatcer.add_entry(inf.get_output_tensors())
all_names = dispatcer.get_all_names()
dispatcher.add_entry(inf.get_output_tensors())
all_names = dispatcher.get_all_names()
IOTensor = InferenceRunner._IOTensor
self.output_tensors = list(filter(
lambda x: x not in self.input_tensors, all_names))
self.output_names = list(filter(
lambda x: x not in self.input_names, all_names))
def find_tensors(names):
ret = []
for name in names:
if name in self.input_tensors:
ret.append(IOTensor(self.input_tensors.index(name), False))
if name in self.input_names:
ret.append(IOTensor(self.input_names.index(name), False))
else:
ret.append(IOTensor(self.output_tensors.index(name), True))
ret.append(IOTensor(self.output_names.index(name), True))
return ret
self.inf_to_tensors = [find_tensors(t) for t in dispatcer.get_names_for_each_entry()]
self.inf_to_tensors = [find_tensors(t) for t in dispatcher.get_names_for_each_entry()]
# list of list of IOTensor
def _trigger(self):
......@@ -183,14 +182,11 @@ class FeedfreeInferenceRunner(Triggerable):
def _setup_graph(self):
self._find_input_tensors() # tensors
# TODO reuse predictor code
# overwrite the FeedfreeInferenceRunner name scope
with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
tf.name_scope(None), \
freeze_collection(SUMMARY_BACKUP_KEYS):
# TODO can we reuse predictor factory?
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
def fn(_):
self.trainer.model.build_graph(self._input_tensors)
build_prediction_graph(fn, [0], prefix=self._prefix)
PredictorTowerBuilder(fn, self._prefix).build(0)
self._tower_prefix = TowerContext.get_predict_tower_name(0, self._prefix)
self._find_output_tensors()
......@@ -220,16 +216,16 @@ class FeedfreeInferenceRunner(Triggerable):
def _find_output_tensors(self):
# TODO doesn't support output an input tensor
dispatcer = OutputTensorDispatcer()
dispatcher = OutputTensorDispatcher()
for inf in self.infs:
dispatcer.add_entry(inf.get_output_tensors())
all_names = dispatcer.get_all_names()
dispatcher.add_entry(inf.get_output_tensors())
all_names = dispatcher.get_all_names()
G = tf.get_default_graph()
self._output_tensors = [G.get_tensor_by_name(
self._tower_prefix + '/' + n) for n in all_names]
# list of list of id
self.inf_to_idxs = dispatcer.get_idx_for_each_entry()
self.inf_to_idxs = dispatcher.get_idx_for_each_entry()
def _trigger(self):
sess = tf.get_default_session()
......
......@@ -129,8 +129,10 @@ class Trainer(object):
self.config.callbacks.setup_graph(weakref.proxy(self))
self.config.session_init._setup_graph()
def after_init(_, __):
def after_init(scaffold, sess):
logger.info("Graph variables initialized.")
self.config.session_init._run_init(sess)
scaffold = tf.train.Scaffold(
init_op=tf.global_variables_initializer(),
init_fn=after_init)
......@@ -140,9 +142,7 @@ class Trainer(object):
scaffold=scaffold, config=self.config.session_config),
hooks=self.config.callbacks.get_hooks())
self.hooked_sess = self.monitored_sess # just create an alias
self.sess = self.monitored_sess._tf_sess() # expose the underlying session also
self.config.session_init._run_init(self.sess)
@abstractmethod
def _setup(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment