Commit e839c50d authored by Yuxin Wu

move PredictorFactory to graph_builder

parent efe3dfb5
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: predict.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# File: predictor_factory.py
import tensorflow as tf
# from ..tfutils.tower import TowerContext
from ..predict import (OnlinePredictor,
                       PredictorTowerBuilder)
__all__ = ['PredictorFactory']
# class PredictorTowerBuilder(object):
# def __init__(self, model):
# self._model = model
# self._towers = []
#
# def build(self, tower_name, device, input=None):
# with tf.device(device), TowerContext(tower_name, is_training=False):
# if input is None:
# input = self._model.get_reused_placehdrs()
# self._model.build_graph(input)
#
#
# SMART
class PredictorFactory(object):
""" Make predictors from a trainer."""
""" Make predictors from :class:`ModelDesc` and cache them."""
def __init__(self, trainer):
def __init__(self, model, towers, vs_name):
"""
Args:
towers (list[int]): list of gpu id
towers (list[int]): list of available gpu id
"""
self.model = trainer.model
self.towers = trainer.config.predict_tower
self.vs_name = trainer.vs_name_for_predictor
self.model = model
self.towers = towers
self.vs_name = vs_name
def fn(_):
self.model.build_graph(self.model.get_reused_placehdrs())
......
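Side note on the PredictorFactory excerpt above: the constructor no longer needs a whole Trainer, only the three things it actually pulled from it (the model, the predict towers, and the variable-scope name). A minimal sketch of the new call site, with `my_model` standing in for any ModelDesc instance (illustrative names, not part of this commit):

from tensorpack.graph_builder.predictor_factory import PredictorFactory

# my_model: any ModelDesc; towers and vs_name values are illustrative.
factory = PredictorFactory(my_model, towers=[0], vs_name='')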
@@ -14,22 +14,26 @@ _CurrentTowerContext = None
class TowerContext(object):
""" A context where the current model is being built in. """
def __init__(self, tower_name, is_training=None,
index=0, vs_name=''):
def __init__(self, tower_name, is_training=None, index=0, vs_name=''):
"""
Args:
tower_name (str): The name scope of the tower. Currently used
values are like: 'tower0', 'towerp0', or ''
is_training (bool): if None, automatically determine from tower_name.
index (int): index of this tower
index (int): index of this tower.
vs_name (str): Open a variable scope with this name, if given.
"""
self._name = tower_name
if is_training is None:
# TODO remove this
is_training = not self._name.startswith(PREDICT_TOWER)
self._is_training = bool(is_training)
if not self._is_training:
# TODO ugly
assert index == 0 and vs_name == '', "vs_name and index are meaningless in prediction!"
self._index = int(index)
self._vs_name = str(vs_name)
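Side note on the new assert in __init__ above: prediction towers must keep index=0 and an empty vs_name, so an inference-only graph is built roughly like this (my_model and my_inputs are placeholders for illustration, not from this commit):

with TowerContext('', is_training=False):   # defaults satisfy the new assert
    my_model.build_graph(my_inputs)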
@@ -40,10 +44,6 @@ class TowerContext(object):
    def is_main_training_tower(self):
        return self.is_training and self._index == 0

    @property
    def is_main_tower(self):
        return self._index == 0

    @property
    def is_training(self):
        return self._is_training
@@ -113,11 +113,15 @@ class TowerContext(object):
        if self.is_training:
            reuse = self._index > 0
            if reuse is True:
                # clear old name_scope and re-enter the current variable_scope
                # clear old name_scope (due to the existing variable_scope)
                # and re-enter the current variable_scope
                self._ctxs.append(tf.name_scope(None))
                self._ctxs.append(tf.variable_scope(
                    tf.get_variable_scope(), reuse=True))
        # if not training, should handle vs outside (TODO not good)
        else:
            # if not training, should handle reuse outside
            # but still good to clear name_scope first
            self._ctxs.append(tf.name_scope(None))
        self._ctxs.append(tf.name_scope(self._name))
        for c in self._ctxs:
            c.__enter__()
......
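Side note on the TowerContext scope handling above: in TF1 graph mode, entering tf.name_scope(None) resets to the root name scope, which is why both the variable-reuse branch and the prediction branch clear it before re-opening the tower's own name scope. A standalone illustration (not from this commit):

import tensorflow as tf  # TF1-style graph construction assumed

with tf.variable_scope('tower0'):
    a = tf.constant(1, name='a')      # op name: 'tower0/a'
    with tf.name_scope(None):         # reset to the root name scope
        b = tf.constant(1, name='b')  # op name: 'b'
print(a.op.name, b.op.name)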
@@ -10,7 +10,7 @@ from six.moves import range
import tensorflow as tf
from .predict import PredictorFactory
from ..graph_builder.predictor_factory import PredictorFactory
from .config import TrainConfig
from ..utils import logger
from ..callbacks import Callback, Callbacks, MaintainStepCounter
@@ -217,6 +217,7 @@ class Trainer(object):
"""
The variable scope name a predictor should be built in.
"""
# TODO graphbuilder knows it
return ""
def get_predictor(self, input_names, output_names, tower=0):
@@ -229,7 +230,8 @@
            an :class:`OnlinePredictor`.
        """
        if not hasattr(self, '_predictor_factory'):
            self._predictor_factory = PredictorFactory(self)
            self._predictor_factory = PredictorFactory(
                self.model, self.config.predict_tower, self.vs_name_for_predictor)
        nr_tower = len(self.config.predict_tower)
        if nr_tower < tower:
            logger.warn(
......
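Side note on Trainer.get_predictor above: call sites do not change, only the factory construction behind the method does. Typical usage, with illustrative tensor names (the model defines the real ones):

pred = trainer.get_predictor(['input'], ['prob'])  # returns an OnlinePredictor
prob = pred(image_batch)                           # positional args follow input_names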
@@ -17,7 +17,7 @@ __all__ = ['FeedfreeTrainerBase', 'SingleCostFeedfreeTrainer',
class FeedfreeTrainerBase(Trainer):
""" A base trainer which runs iteration without feed_dict (therefore faster)
Expect ``self.data`` to be a :class:`FeedfreeInput`.
Expect ``config.data`` to be a :class:`FeedfreeInput`.
"""
@deprecated("Please build the graph yourself, e.g. by self.model.build_graph(self._input_source)")
......
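Side note on the FeedfreeTrainerBase docstring fix above: the input pipeline is expected on the config, not on the trainer instance. A hedged sketch of assembling such a config; the top-level imports and the QueueInput choice are assumptions, not part of this commit:

from tensorpack import TrainConfig, QueueInput  # assumed top-level exports

# QueueInput wraps a DataFlow into a FeedfreeInput, which is what
# FeedfreeTrainerBase expects to find at config.data.
config = TrainConfig(model=my_model, data=QueueInput(my_dataflow))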
@@ -54,7 +54,7 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase):
        Args:
            towers: list of gpu relative ids
            func: a lambda to be called inside each tower
            devices: a list of devices to be used. By default will use GPUs in towers.
            devices: a list of devices to be used. By default will use GPUs in ``towers``.
            var_strategy (str): 'shared' or 'replicated'
            vs_names (list[str]): list of variable scope names to use.
......