Commit 4ee1e735 authored by Yuxin Wu's avatar Yuxin Wu

PredictorFactory build tower by itself.

parent e839c50d
......@@ -3,58 +3,89 @@
# File: predictor_factory.py
import tensorflow as tf
# from ..tfutils.tower import TowerContext
from ..predict import (OnlinePredictor,
PredictorTowerBuilder)
from ..utils import logger
from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
from ..tfutils.tower import TowerContext
from ..tfutils.collection import freeze_collection
from ..predict import OnlinePredictor
from ..utils.naming import TOWER_FREEZE_KEYS
__all__ = ['PredictorFactory']
# class PredictorTowerBuilder(object):
# def __init__(self, model):
# self._model = model
# self._towers = []
#
# def build(self, tower_name, device, input=None):
# with tf.device(device), TowerContext(tower_name, is_training=False):
# if input is None:
# input = self._model.get_reused_placehdrs()
# self._model.build_graph(input)
#
#
# SMART
class PredictorTowerHandle(object):
    """A handle to one already-built predict tower, able to resolve
    tensor names to tensors inside that tower."""

    def __init__(self, tower_name, input_tensors):
        """
        Args:
            tower_name (str): name scope the tower was built under.
            input_tensors (list): the input tensors (placeholders) the
                tower was built with.
        """
        self._tower_name = tower_name
        self._input_tensors = input_tensors
        self._input_names = [get_op_tensor_name(t.name)[1] for t in input_tensors]

    def get_tensors(self, names):
        """
        Map a list of names to tensors in this tower.

        Args:
            names (list[str]): tensor or op names.

        Returns:
            list of tensors.
        """
        def _resolve(name):
            name = get_op_tensor_name(name)[1]
            if name in self._input_names:
                # placeholders live outside the tower name scope
                return name
            # otherwise the tensor was created inside the tower: use its
            # name under the tower's scope
            return self._tower_name + '/' + name

        resolved = [_resolve(n) for n in names]
        return get_tensors_by_names(resolved)
class PredictorFactory(object):
    """ Make predictors from :class:`ModelDesc` and cache them."""

    def __init__(self, model, towers, vs_name):
        """
        Args:
            model (ModelDesc):
            towers (list[int]): list of available gpu id
            vs_name (str): variable scope name to reuse when building a
                predict tower.
        """
        assert isinstance(towers, list), towers
        self._model = model
        self._towers = towers
        self._vs_name = vs_name
        # tower_name -> PredictorTowerHandle; memoizes towers built so far
        self._names_built = {}

    def build(self, tower_name, device, input=None):
        """
        Build one predict tower on the given device.

        Args:
            tower_name (str): name scope to build the tower under.
            device (str): device placement string, e.g. '/gpu:0'.
            input: an object providing ``get_input_tensors()``, or None to
                use the model's reused placeholders.

        Returns:
            PredictorTowerHandle: handle to the newly-built tower.
        """
        logger.info("Building predictor graph {} on device {} ...".format(tower_name, device))
        assert tower_name not in self._names_built
        # freeze TOWER_FREEZE_KEYS collections so that building this
        # inference-only tower does not pollute the training collections
        with tf.device(device), \
                TowerContext(tower_name, is_training=False), \
                freeze_collection(TOWER_FREEZE_KEYS):
            if input is None:
                input = self._model.get_reused_placehdrs()
            else:
                input = input.get_input_tensors()
            assert isinstance(input, (list, tuple)), input
            self._model.build_graph(input)
        self._names_built[tower_name] = PredictorTowerHandle(tower_name, input)
        return self._names_built[tower_name]

    def has_built(self, tower_name):
        """Whether a tower with this name was already built by this factory."""
        return tower_name in self._names_built

    def get_predictor(self, input_names, output_names, tower):
        """
        Args:
            input_names (list[str]): input tensor/op names.
            output_names (list[str]): output tensor/op names.
            tower (int): need the kth tower (not the gpu id, but the id in TrainConfig.predict_tower)

        Returns:
            an online predictor (which has to be used under the default session)
        """
        tower = self._towers[tower]
        # negative id means CPU
        device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'
        tower_name = TowerContext.get_predict_tower_name(max(tower, 0))  # XXX
        # use a previously-built tower
        # TODO conflict with inference runner??
        if not self.has_built(tower_name):
            with tf.variable_scope(self._vs_name, reuse=True):
                handle = self.build(tower_name, device)
        else:
            handle = self._names_built[tower_name]
        in_tensors = handle.get_tensors(input_names)
        out_tensors = handle.get_tensors(output_names)
        return OnlinePredictor(in_tensors, out_tensors)
......@@ -10,7 +10,8 @@ import six
from ..utils import logger
from ..utils.argtools import memoized
from ..utils.naming import TOWER_FREEZE_KEYS
from ..tfutils import get_tensors_by_names, TowerContext, get_op_tensor_name
from ..tfutils.common import get_tensors_by_names, get_op_tensor_name
from ..tfutils.tower import TowerContext
from ..tfutils.collection import freeze_collection
__all__ = ['PredictorBase', 'AsyncPredictorBase',
......
......@@ -47,7 +47,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
def get_predictor(self, n):
"""
Returns:
PredictorBase: the nth predictor on the nth tower.
OnlinePredictor: the nth predictor on the nth tower.
"""
l = len(self.predictors)
if n >= l:
......@@ -57,7 +57,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
def get_predictors(self):
"""
Returns:
list[PredictorBase]: a list of predictor
list[OnlinePredictor]: a list of predictor
"""
return self.predictors
......
......@@ -49,11 +49,14 @@ def create_image_summary(name, val):
val = val.astype('uint8')
s = tf.Summary()
for k in range(n):
arr = val[k]
if arr.shape[2] == 1: # scipy doesn't accept (h,w,1)
arr = arr[:, :, 0]
tag = name if n == 1 else '{}/{}'.format(name, k)
buf = io.BytesIO()
# scipy assumes RGB
scipy.misc.toimage(val[k]).save(buf, format='png')
scipy.misc.toimage(arr).save(buf, format='png')
img = tf.Summary.Image()
img.height = h
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment