Commit fb501e66 authored by Yuxin Wu

InferenceRunnerBase uses PredictorFactory to build graph

parent 4ee1e735
......@@ -87,12 +87,11 @@ class InferenceRunnerBase(Callback):
def _setup_graph(self):
self._input_source.setup(self.trainer.model.get_inputs_desc())
# Use predict_tower in train config. either gpuid or -1
self._predict_tower_id = self.trainer.config.predict_tower[0]
def fn(_):
self.trainer.model.build_graph(self._input_source)
with tf.variable_scope(self.trainer.vs_name_for_predictor, reuse=True):
PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
tower_id = self.trainer.config.predict_tower[0]
device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
tower_name = TowerContext.get_predict_tower_name(tower_id, prefix=self._prefix)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
self._tower_handle = self.trainer.predictor_factory.build(tower_name, device, self._input_source)
self._hooks = [self._build_hook(inf) for inf in self.infs]
cbs = self._input_source.get_callbacks()
......@@ -102,11 +101,6 @@ class InferenceRunnerBase(Callback):
self._hooks.extend(self._extra_hooks)
self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)
def _get_tensors_maybe_in_tower(self, names):
    """
    Look up tensors by name via
    ``PredictorTowerBuilder.get_tensors_maybe_in_tower``.

    The names of the model's input descriptions are passed along as the
    placeholder names, together with this runner's predict tower id and
    name prefix.

    Args:
        names: names of the tensors to look up.

    Returns:
        the result of the builder's lookup for ``names``.
    """
    input_descs = self.trainer.model.get_inputs_desc()
    placeholder_names = {desc.name for desc in input_descs}
    return PredictorTowerBuilder.get_tensors_maybe_in_tower(
        placeholder_names, names, self._predict_tower_id, prefix=self._prefix)
@abstractmethod
def _build_hook(self, inf):
    """
    Build the hook for one inferencer. Subclasses must override this.

    Args:
        inf: one of the inferencers in ``self.infs``.

    Returns:
        the hook for ``inf``; the base implementation returns nothing.
    """
......@@ -142,7 +136,7 @@ class InferenceRunner(InferenceRunnerBase):
def _build_hook(self, inf):
    """
    Create the session hook that runs one inferencer.

    Args:
        inf: the inferencer; its ``get_output_tensors()`` supplies the
            names of the tensors it wants fetched.

    Returns:
        an ``InferencerToHook`` wrapping ``inf`` and its fetches.
    """
    out_names = inf.get_output_tensors()
    # Single lookup through the tower handle; the result of any earlier
    # lookup would be overwritten here anyway, so none is performed.
    fetches = self._tower_handle.get_tensors(out_names)
    return InferencerToHook(inf, fetches)
......@@ -170,7 +164,7 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
ret = []
for name in out_names:
assert name not in placeholder_names, "Currently inferencer don't support fetching placeholders!"
ret.append(self._get_tensors_maybe_in_tower([name])[0])
ret.append(self._tower_handle.get_tensors([name])[0])
return InferencerToHook(inf, ret)
......
......@@ -33,7 +33,7 @@ class PredictorTowerHandle(object):
class PredictorFactory(object):
""" Make predictors from :class:`ModelDesc` and cache them."""
""" Make predictors from :class:`ModelDesc`."""
def __init__(self, model, towers, vs_name):
"""
......
......@@ -167,7 +167,6 @@ class PredictorTowerBuilder(object):
tower (int): the tower will be built on device '/gpu:{tower}', or
'/cpu:0' if tower is -1.
"""
toweridx = max(tower, 0) # if CPU, named the tower as 0
towername = TowerContext.get_predict_tower_name(tower, self._prefix)
if self._prefix:
msg = "Building predictor graph {} on gpu={} with prefix='{}' ...".format(
......@@ -180,7 +179,7 @@ class PredictorTowerBuilder(object):
with tf.name_scope(None), \
freeze_collection(TOWER_FREEZE_KEYS), \
tf.device(device), \
TowerContext(towername, is_training=False, index=toweridx):
TowerContext(towername, is_training=False):
self._fn(tower)
# useful only when the placeholders don't have tower prefix
......
......@@ -211,7 +211,7 @@ class Trainer(object):
self._callbacks.after_train()
self.hooked_sess.close()
# Predictor related methods: TODO
# Predictor related methods:
@property
def vs_name_for_predictor(self):
"""
......@@ -229,16 +229,21 @@ class Trainer(object):
Returns:
an :class:`OnlinePredictor`.
"""
if not hasattr(self, '_predictor_factory'):
self._predictor_factory = PredictorFactory(
self.model, self.config.predict_tower, self.vs_name_for_predictor)
# TODO move the logic to factory?
nr_tower = len(self.config.predict_tower)
if nr_tower < tower:
logger.warn(
"Requested the {}th predictor but only have {} predict towers! "
"Predictors will be assigned to GPUs in round-robin.".format(tower, nr_tower))
tower = tower % nr_tower
return self._predictor_factory.get_predictor(input_names, output_names, tower)
return self.predictor_factory.get_predictor(input_names, output_names, tower)
@property
def predictor_factory(self):
    """
    The ``PredictorFactory`` of this trainer.

    It is created on first access from the model, the configured predict
    towers and ``vs_name_for_predictor``, then stored on the instance and
    reused by later accesses.
    """
    try:
        return self._predictor_factory
    except AttributeError:
        # Not built yet: fall through and create it.
        pass
    factory = PredictorFactory(
        self.model, self.config.predict_tower, self.vs_name_for_predictor)
    self._predictor_factory = factory
    return factory
def get_predictors(self, input_names, output_names, n):
""" Return n predictors. """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment