Commit e839c50d authored by Yuxin Wu's avatar Yuxin Wu

move PredictorFactory to graph_builder

parent efe3dfb5
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: predict.py # File: predictor_factory.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf import tensorflow as tf
# from ..tfutils.tower import TowerContext
from ..predict import (OnlinePredictor, from ..predict import (OnlinePredictor,
PredictorTowerBuilder) PredictorTowerBuilder)
__all__ = ['PredictorFactory'] __all__ = ['PredictorFactory']
# class PredictorTowerBuilder(object):
# def __init__(self, model):
# self._model = model
# self._towers = []
#
# def build(self, tower_name, device, input=None):
# with tf.device(device), TowerContext(tower_name, is_training=False):
# if input is None:
# input = self._model.get_reused_placehdrs()
# self._model.build_graph(input)
#
#
# SMART
class PredictorFactory(object): class PredictorFactory(object):
""" Make predictors from a trainer.""" """ Make predictors from :class:`ModelDesc` and cache them."""
def __init__(self, trainer): def __init__(self, model, towers, vs_name):
""" """
Args: Args:
towers (list[int]): list of gpu id towers (list[int]): list of available gpu id
""" """
self.model = trainer.model self.model = model
self.towers = trainer.config.predict_tower self.towers = towers
self.vs_name = trainer.vs_name_for_predictor self.vs_name = vs_name
def fn(_): def fn(_):
self.model.build_graph(self.model.get_reused_placehdrs()) self.model.build_graph(self.model.get_reused_placehdrs())
......
...@@ -14,22 +14,26 @@ _CurrentTowerContext = None ...@@ -14,22 +14,26 @@ _CurrentTowerContext = None
class TowerContext(object): class TowerContext(object):
""" A context where the current model is being built in. """ """ A context where the current model is being built in. """
def __init__(self, tower_name, is_training=None, def __init__(self, tower_name, is_training=None, index=0, vs_name=''):
index=0, vs_name=''):
""" """
Args: Args:
tower_name (str): The name scope of the tower. Currently used tower_name (str): The name scope of the tower. Currently used
values are like: 'tower0', 'towerp0', or '' values are like: 'tower0', 'towerp0', or ''
is_training (bool): if None, automatically determine from tower_name. is_training (bool): if None, automatically determine from tower_name.
index (int): index of this tower index (int): index of this tower.
vs_name (str): Open a variable scope with this name, if given. vs_name (str): Open a variable scope with this name, if given.
""" """
self._name = tower_name self._name = tower_name
if is_training is None: if is_training is None:
# TODO remove this
is_training = not self._name.startswith(PREDICT_TOWER) is_training = not self._name.startswith(PREDICT_TOWER)
self._is_training = bool(is_training) self._is_training = bool(is_training)
if not self._is_training:
# TODO ugly
assert index == 0 and vs_name == '', "vs_name and index are meaningless in prediction!"
self._index = int(index) self._index = int(index)
self._vs_name = str(vs_name) self._vs_name = str(vs_name)
...@@ -40,10 +44,6 @@ class TowerContext(object): ...@@ -40,10 +44,6 @@ class TowerContext(object):
def is_main_training_tower(self): def is_main_training_tower(self):
return self.is_training and self._index == 0 return self.is_training and self._index == 0
@property
def is_main_tower(self):
return self._index == 0
@property @property
def is_training(self): def is_training(self):
return self._is_training return self._is_training
...@@ -113,11 +113,15 @@ class TowerContext(object): ...@@ -113,11 +113,15 @@ class TowerContext(object):
if self.is_training: if self.is_training:
reuse = self._index > 0 reuse = self._index > 0
if reuse is True: if reuse is True:
# clear old name_scope and re-enter the current variable_scope # clear old name_scope (due to the existing variable_scope)
# and re-enter the current variable_scope
self._ctxs.append(tf.name_scope(None)) self._ctxs.append(tf.name_scope(None))
self._ctxs.append(tf.variable_scope( self._ctxs.append(tf.variable_scope(
tf.get_variable_scope(), reuse=True)) tf.get_variable_scope(), reuse=True))
# if not training, should handle vs outside (TODO not good) else:
# if not training, should handle reuse outside
# but still good to clear name_scope first
self._ctxs.append(tf.name_scope(None))
self._ctxs.append(tf.name_scope(self._name)) self._ctxs.append(tf.name_scope(self._name))
for c in self._ctxs: for c in self._ctxs:
c.__enter__() c.__enter__()
......
...@@ -10,7 +10,7 @@ from six.moves import range ...@@ -10,7 +10,7 @@ from six.moves import range
import tensorflow as tf import tensorflow as tf
from .predict import PredictorFactory from ..graph_builder.predictor_factory import PredictorFactory
from .config import TrainConfig from .config import TrainConfig
from ..utils import logger from ..utils import logger
from ..callbacks import Callback, Callbacks, MaintainStepCounter from ..callbacks import Callback, Callbacks, MaintainStepCounter
...@@ -217,6 +217,7 @@ class Trainer(object): ...@@ -217,6 +217,7 @@ class Trainer(object):
""" """
The variable scope name a predictor should be built in. The variable scope name a predictor should be built in.
""" """
# TODO graphbuilder knows it
return "" return ""
def get_predictor(self, input_names, output_names, tower=0): def get_predictor(self, input_names, output_names, tower=0):
...@@ -229,7 +230,8 @@ class Trainer(object): ...@@ -229,7 +230,8 @@ class Trainer(object):
an :class:`OnlinePredictor`. an :class:`OnlinePredictor`.
""" """
if not hasattr(self, '_predictor_factory'): if not hasattr(self, '_predictor_factory'):
self._predictor_factory = PredictorFactory(self) self._predictor_factory = PredictorFactory(
self.model, self.config.predict_tower, self.vs_name_for_predictor)
nr_tower = len(self.config.predict_tower) nr_tower = len(self.config.predict_tower)
if nr_tower < tower: if nr_tower < tower:
logger.warn( logger.warn(
......
...@@ -17,7 +17,7 @@ __all__ = ['FeedfreeTrainerBase', 'SingleCostFeedfreeTrainer', ...@@ -17,7 +17,7 @@ __all__ = ['FeedfreeTrainerBase', 'SingleCostFeedfreeTrainer',
class FeedfreeTrainerBase(Trainer): class FeedfreeTrainerBase(Trainer):
""" A base trainer which runs iteration without feed_dict (therefore faster) """ A base trainer which runs iteration without feed_dict (therefore faster)
Expect ``self.data`` to be a :class:`FeedfreeInput`. Expect ``config.data`` to be a :class:`FeedfreeInput`.
""" """
@deprecated("Please build the graph yourself, e.g. by self.model.build_graph(self._input_source)") @deprecated("Please build the graph yourself, e.g. by self.model.build_graph(self._input_source)")
......
...@@ -54,7 +54,7 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase): ...@@ -54,7 +54,7 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase):
Args: Args:
towers: list of gpu relative ids towers: list of gpu relative ids
func: a lambda to be called inside each tower func: a lambda to be called inside each tower
devices: a list of devices to be used. By default will use GPUs in towers. devices: a list of devices to be used. By default will use GPUs in ``towers``.
var_strategy (str): 'shared' or 'replicated' var_strategy (str): 'shared' or 'replicated'
vs_names (list[str]): list of variable scope names to use. vs_names (list[str]): list of variable scope names to use.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment