Commit fb501e66 authored by Yuxin Wu

InferenceRunnerBase uses PredictorFactory to build graph

parent 4ee1e735
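In short: InferenceRunnerBase now builds its inference graph through the trainer's shared PredictorFactory, which returns a tower handle used later to look up output tensors by name, instead of driving PredictorTowerBuilder itself and re-deriving tower-prefixed names inside the callback. As a rough illustration of that factory/handle pattern (a toy sketch with made-up names, not tensorpack code):

# Illustrative sketch only (hypothetical names, not the tensorpack API):
# a factory builds a predict tower once and returns a handle; callers ask
# the handle for tensors by name instead of re-deriving tower prefixes.

class ToyTowerHandle(object):
    def __init__(self, tower_name, tensors, placeholder_names):
        self._tower_name = tower_name
        self._tensors = tensors                      # name -> toy "tensor"
        self._placeholder_names = placeholder_names

    def get_tensors(self, names):
        # Placeholders live outside the tower; other names get the tower prefix.
        out = []
        for name in names:
            if name in self._placeholder_names:
                out.append(self._tensors[name])
            else:
                out.append(self._tensors['{}/{}'.format(self._tower_name, name)])
        return out


class ToyPredictorFactory(object):
    def __init__(self):
        self._handles = {}

    def build(self, tower_name, device, input_source):
        # Build the tower once per name and cache the resulting handle.
        if tower_name not in self._handles:
            tensors = {
                'image': 'placeholder:image',
                tower_name + '/prob': 'tensor:' + tower_name + '/prob',
            }
            self._handles[tower_name] = ToyTowerHandle(tower_name, tensors, {'image'})
        return self._handles[tower_name]


if __name__ == '__main__':
    factory = ToyPredictorFactory()
    handle = factory.build('towerp0', '/cpu:0', input_source=None)
    print(handle.get_tensors(['prob']))              # ['tensor:towerp0/prob']

The point of the handle is that the name-resolution logic (placeholders vs. tower-scoped tensors) lives next to the code that built the tower, which is what lets the _get_tensors_maybe_in_tower helper in the diff below be deleted.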
@@ -87,12 +87,11 @@ class InferenceRunnerBase(Callback):
     def _setup_graph(self):
         self._input_source.setup(self.trainer.model.get_inputs_desc())
         # Use predict_tower in train config. either gpuid or -1
-        self._predict_tower_id = self.trainer.config.predict_tower[0]
-
-        def fn(_):
-            self.trainer.model.build_graph(self._input_source)
-        with tf.variable_scope(self.trainer.vs_name_for_predictor, reuse=True):
-            PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
+        tower_id = self.trainer.config.predict_tower[0]
+        device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
+        tower_name = TowerContext.get_predict_tower_name(tower_id, prefix=self._prefix)
+        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+            self._tower_handle = self.trainer.predictor_factory.build(tower_name, device, self._input_source)
 
         self._hooks = [self._build_hook(inf) for inf in self.infs]
         cbs = self._input_source.get_callbacks()
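One small detail worth calling out from the new lines above: the device string is derived directly from the configured predict tower id, with -1 meaning CPU. A trivial stand-alone check of that expression (plain Python, no TensorFlow needed):

# Same conditional expression as the new `device = ...` line above.
def predict_device(tower_id):
    return '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'

assert predict_device(0) == '/gpu:0'     # predict_tower = [0]  -> first GPU
assert predict_device(-1) == '/cpu:0'    # predict_tower = [-1] -> CPU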
@@ -102,11 +101,6 @@ class InferenceRunnerBase(Callback):
         self._hooks.extend(self._extra_hooks)
         self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)
 
-    def _get_tensors_maybe_in_tower(self, names):
-        placeholder_names = set([k.name for k in self.trainer.model.get_inputs_desc()])
-        get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
-        return get_tensor_fn(placeholder_names, names, self._predict_tower_id, prefix=self._prefix)
-
     @abstractmethod
     def _build_hook(self, inf):
         pass
@@ -142,7 +136,7 @@ class InferenceRunner(InferenceRunnerBase):
     def _build_hook(self, inf):
         out_names = inf.get_output_tensors()
-        fetches = self._get_tensors_maybe_in_tower(out_names)
+        fetches = self._tower_handle.get_tensors(out_names)
         return InferencerToHook(inf, fetches)
@@ -170,7 +164,7 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
         ret = []
         for name in out_names:
             assert name not in placeholder_names, "Currently inferencer don't support fetching placeholders!"
-            ret.append(self._get_tensors_maybe_in_tower([name])[0])
+            ret.append(self._tower_handle.get_tensors([name])[0])
         return InferencerToHook(inf, ret)
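Taken together, both InferenceRunner._build_hook and FeedfreeInferenceRunner._build_hook now fetch their output tensors through the handle returned by the factory (self._tower_handle.get_tensors(...)), so the base class no longer carries its own _get_tensors_maybe_in_tower wrapper around PredictorTowerBuilder.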
...
@@ -33,7 +33,7 @@ class PredictorTowerHandle(object):
 class PredictorFactory(object):
-    """ Make predictors from :class:`ModelDesc` and cache them."""
+    """ Make predictors from :class:`ModelDesc`."""
 
     def __init__(self, model, towers, vs_name):
         """
...
@@ -167,7 +167,6 @@ class PredictorTowerBuilder(object):
             tower (int): the tower will be built on device '/gpu:{tower}', or
                 '/cpu:0' if tower is -1.
         """
-        toweridx = max(tower, 0)  # if CPU, named the tower as 0
         towername = TowerContext.get_predict_tower_name(tower, self._prefix)
         if self._prefix:
             msg = "Building predictor graph {} on gpu={} with prefix='{}' ...".format(
@@ -180,7 +179,7 @@ class PredictorTowerBuilder(object):
         with tf.name_scope(None), \
                 freeze_collection(TOWER_FREEZE_KEYS), \
                 tf.device(device), \
-                TowerContext(towername, is_training=False, index=toweridx):
+                TowerContext(towername, is_training=False):
             self._fn(tower)
         # useful only when the placeholders don't have tower prefix
...
@@ -211,7 +211,7 @@ class Trainer(object):
             self._callbacks.after_train()
             self.hooked_sess.close()
 
-    # Predictor related methods: TODO
+    # Predictor related methods:
     @property
     def vs_name_for_predictor(self):
         """
@@ -229,16 +229,21 @@ class Trainer(object):
         Returns:
             an :class:`OnlinePredictor`.
         """
-        if not hasattr(self, '_predictor_factory'):
-            self._predictor_factory = PredictorFactory(
-                self.model, self.config.predict_tower, self.vs_name_for_predictor)
+        # TODO move the logic to factory?
         nr_tower = len(self.config.predict_tower)
         if nr_tower < tower:
             logger.warn(
                 "Requested the {}th predictor but only have {} predict towers! "
                 "Predictors will be assigned to GPUs in round-robin.".format(tower, nr_tower))
         tower = tower % nr_tower
-        return self._predictor_factory.get_predictor(input_names, output_names, tower)
+        return self.predictor_factory.get_predictor(input_names, output_names, tower)
+
+    @property
+    def predictor_factory(self):
+        if not hasattr(self, '_predictor_factory'):
+            self._predictor_factory = PredictorFactory(
+                self.model, self.config.predict_tower, self.vs_name_for_predictor)
+        return self._predictor_factory
 
     def get_predictors(self, input_names, output_names, n):
         """ Return n predictors. """
...
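The trainer side now creates the PredictorFactory lazily: the new predictor_factory property builds one factory on first access and caches it on the instance, so get_predictor and InferenceRunnerBase share the same factory. The sketch below shows only that memoized-property pattern plus the round-robin tower assignment from get_predictor, with hypothetical class names rather than the real Trainer:

# Minimal sketch (hypothetical names): memoize a factory on first access and
# wrap a requested tower index round-robin over the available predict towers.

class ToyTrainer(object):
    def __init__(self, predict_towers):
        self.predict_towers = predict_towers         # e.g. [0, 1]

    @property
    def predictor_factory(self):
        # Created on first access, then reused on every later call.
        if not hasattr(self, '_predictor_factory'):
            self._predictor_factory = object()       # stand-in for PredictorFactory(...)
        return self._predictor_factory

    def pick_tower(self, tower):
        nr_tower = len(self.predict_towers)
        if nr_tower < tower:
            print("Requested the {}th predictor but only have {} predict towers; "
                  "assigning round-robin.".format(tower, nr_tower))
        return tower % nr_tower


t = ToyTrainer([0, 1])
assert t.predictor_factory is t.predictor_factory    # same cached object
assert t.pick_tower(5) == 1                          # 5 % 2 -> second predict tower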