Commit 7a0e8747 authored by Yuxin Wu

build_graph with ctx

parent aabab2cc
@@ -12,61 +12,6 @@ from ..utils import *
 __all__ = ['Callbacks']

-# --- Test-Callback related stuff seems not very useful.
-@contextmanager
-def create_test_graph(trainer):
-    model = trainer.model
-    with tf.Graph().as_default() as Gtest:
-        # create a global step var in test graph
-        global_step_var = tf.Variable(
-            0, trainable=False, name=GLOBAL_STEP_OP_NAME)
-        input_vars = model.get_input_vars()
-        model.build_graph(input_vars, False)
-        cost = model.get_cost()
-        yield Gtest
-
-@contextmanager
-def create_test_session(trainer):
-    """ create a test-time session from trainer"""
-    with create_test_graph(trainer):
-        with tf.Session() as sess:
-            yield sess
-
-class TestCallbackContext(object):
-    """
-    A class holding the context needed for running TestCallback
-    """
-    def __init__(self):
-        self.sess = None
-
-    @contextmanager
-    def create_context(self, trainer):
-        if self.sess is None:
-            with create_test_session(trainer) as sess:
-                self.sess = sess
-                self.graph = sess.graph
-                # no tower in test graph. just keep it as what it is
-                self.saver = tf.train.Saver()
-        with self.graph.as_default(), self.sess.as_default():
-            yield
-
-    # TODO also do this for after_train?
-    def restore_checkpoint(self):
-        ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
-        if ckpt is None:
-            raise RuntimeError(
-                "Cannot find a checkpoint state. Do you forget to use ModelSaver before all TestCallback?")
-        logger.info(
-            "Restore checkpoint from {}".format(ckpt.model_checkpoint_path))
-        self.saver.restore(self.sess, ckpt.model_checkpoint_path)
-
-    @contextmanager
-    def test_context(self):
-        with self.graph.as_default(), self.sess.as_default():
-            yield
-# ---
-
 class CallbackTimeLogger(object):
     def __init__(self):
         self.times = []
...
@@ -7,6 +7,7 @@ import tensorflow as tf
 from copy import copy
 import re

+from .model_desc import get_current_tower_context
 from ..utils import logger, EXTRA_SAVE_VARS_KEY
 from ._common import layer_register
@@ -54,6 +55,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
     # XXX a hack to handle training tower & prediction tower together....
     emaname = 'EMA'
+    #ctx = get_current_model_context()
     if not batch_mean.name.startswith('towerp'):
         # training tower
         with tf.name_scope(None):   # https://github.com/tensorflow/tensorflow/issues/2740
...
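Note: BatchNorm still tells training towers from prediction towers by sniffing the tensor name for the 'towerp' prefix; the commented-out `ctx` line only hints at querying the new tower context instead. A minimal sketch of that direction, assuming the package-level import works; the helper itself is hypothetical and not part of this commit:

    from tensorpack.models import get_current_tower_context

    def in_training_tower():
        # Prediction towers are named 'towerp*', and TowerContext defaults
        # is_training to `not name.startswith('towerp')`, so the context
        # already encodes what the string check tests for. Outside any
        # context (e.g. a standalone test graph), treat it as training.
        ctx = get_current_tower_context()
        return ctx is None or ctx.is_training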
@@ -6,14 +6,54 @@
 from abc import ABCMeta, abstractmethod
 import tensorflow as tf
 from collections import namedtuple
+import inspect

 from ..utils import logger, INPUT_VARS_KEY
 from ..tfutils import *

-__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']
+__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph',
+           'get_current_tower_context', 'TowerContext']

 InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])

+_CurrentTowerContext = None
+
+class TowerContext(object):
+    def __init__(self, tower_name, is_training=None):
+        """ tower_name: 'tower0', 'towerp0', or '' """
+        self._name = tower_name
+        if is_training is None:
+            is_training = not self._name.startswith('towerp')
+        self._is_training = is_training
+
+    @property
+    def is_main_tower(self):
+        return self._name == '' or self._name == 'tower0'
+
+    @property
+    def is_training(self):
+        return self._is_training
+
+    def __enter__(self):
+        global _CurrentTowerContext
+        assert _CurrentTowerContext is None, \
+            "Nesting TowerContext!"
+        _CurrentTowerContext = self
+        if len(self._name):
+            self._scope = tf.name_scope(self._name)
+            return self._scope.__enter__()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        global _CurrentTowerContext
+        _CurrentTowerContext = None
+        if len(self._name):
+            self._scope.__exit__(exc_type, exc_val, exc_tb)
+        return False
+
+def get_current_tower_context():
+    global _CurrentTowerContext
+    return _CurrentTowerContext
+
 class ModelDesc(object):
     """ Base class for a model description """
     __metaclass__ = ABCMeta
@@ -49,7 +89,7 @@ class ModelDesc(object):
     def _get_input_vars(self):
         """:returns: a list of InputVar """

-    def build_graph(self, model_inputs, is_training):
+    def build_graph(self, model_inputs):
         """
         Setup the whole graph.
@@ -57,10 +97,15 @@ class ModelDesc(object):
         :param is_training: a boolean
         :returns: the cost to minimize. a scalar variable
         """
-        self._build_graph(model_inputs, is_training)
+        if len(inspect.getargspec(self._build_graph).args) == 3:
+            logger.warn("_build_graph(self, input_vars, is_training) is deprecated! \
+Use _build_graph(self, input_vars) and get_current_tower_context().is_training instead.")
+            self._build_graph(model_inputs, get_current_tower_context().is_training)
+        else:
+            self._build_graph(model_inputs)

     @abstractmethod
-    def _build_graph(self, inputs, is_training):
+    def _build_graph(self, inputs):
         pass

     def get_cost(self):
...
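The shim in `build_graph` keeps old three-argument `_build_graph(self, input_vars, is_training)` implementations working behind a deprecation warning. Under the new convention a model reads the training flag from the ambient context instead. A minimal sketch of a new-style model (hypothetical `MyModel`; the network body and the `self.cost` wiring are placeholder assumptions, not taken from this commit):

    import tensorflow as tf
    from tensorpack.models import ModelDesc, InputVar, get_current_tower_context

    class MyModel(ModelDesc):
        def _get_input_vars(self):
            return [InputVar(tf.float32, [None, 28, 28], 'input'),
                    InputVar(tf.int32, [None], 'label')]

        def _build_graph(self, input_vars):
            # two arguments only, so build_graph() dispatches here without a warning
            image, label = input_vars
            is_training = get_current_tower_context().is_training
            keep_prob = 0.5 if is_training else 1.0   # e.g. gate dropout on the context
            # ... real layers would go here; a dummy scalar cost for the sketch:
            self.cost = tf.reduce_mean(tf.to_float(label)) * keep_prob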
@@ -7,6 +7,7 @@ from abc import abstractmethod, ABCMeta, abstractproperty
 import tensorflow as tf
 import six

+from ..models import TowerContext
 from ..utils import logger
 from ..tfutils import get_vars_by_names
@@ -88,7 +89,8 @@ class OfflinePredictor(OnlinePredictor):
         self.graph = tf.Graph()
         with self.graph.as_default():
             input_vars = config.model.get_input_vars()
-            config.model._build_graph(input_vars, False)
+            with TowerContext('', False):
+                config.model.build_graph(input_vars)

             input_vars = get_vars_by_names(config.input_var_names)
             output_vars = get_vars_by_names(config.output_var_names)
@@ -99,7 +101,7 @@ class OfflinePredictor(OnlinePredictor):
                 sess, input_vars, output_vars, config.return_input)

-def build_multi_tower_prediction_graph(model, towers, prefix='towerp'):
+def build_multi_tower_prediction_graph(model, towers):
     """
     :param towers: a list of gpu relative id.
     """
@@ -107,26 +109,24 @@ def build_multi_tower_prediction_graph(model, towers):
     for k in towers:
         logger.info(
             "Building graph for predictor tower {}...".format(k))
-        with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'),\
-                tf.name_scope('{}{}'.format(prefix, k)):
-            model._build_graph(input_vars, False)
+        with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
+                TowerContext('towerp{}'.format(k)):
+            model.build_graph(input_vars)
             tf.get_variable_scope().reuse_variables()

 class MultiTowerOfflinePredictor(OnlinePredictor):
-    PREFIX = 'towerp'
-
     def __init__(self, config, towers):
         self.graph = tf.Graph()
         self.predictors = []
         with self.graph.as_default():
             # TODO backup summary keys?
-            build_multi_tower_prediction_graph(config.model, towers, self.PREFIX)
+            build_multi_tower_prediction_graph(config.model, towers)
             self.sess = tf.Session(config=config.session_config)
             config.session_init.init(self.sess)
             input_vars = get_vars_by_names(config.input_var_names)

-            # use the first tower for compatible PredictorBase interface
             for k in towers:
                 output_vars = get_vars_by_names(
                     ['{}{}/'.format(self.PREFIX, k) + n \
@@ -135,6 +135,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
                         self.sess, input_vars, output_vars, config.return_input))

     def _do_call(self, dp):
+        # use the first tower for compatible PredictorBase interface
         return self.predictors[0]._do_call(dp)

     def get_predictors(self, n):
...
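Since `TowerContext` derives `is_training` from the tower name by default, an unnamed tower would count as a training tower; `OfflinePredictor` therefore passes `is_training=False` explicitly, while the `towerp*` names used for prediction towers already imply it. A small illustration of the two constructions, a sketch assuming the package-level exports:

    from tensorpack.models import TowerContext, get_current_tower_context

    # unnamed tower: must say is_training=False explicitly, as OfflinePredictor does
    with TowerContext('', False):
        ctx = get_current_tower_context()
        assert ctx.is_main_tower and not ctx.is_training

    # a 'towerp*' name already implies a prediction tower
    with TowerContext('towerp0'):
        assert not get_current_tower_context().is_training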
@@ -7,6 +7,7 @@ import tensorflow as tf
 import itertools, re
 from six.moves import zip, range

+from ..models import TowerContext
 from ..utils import *
 from ..utils.concurrency import LoopThread
 from ..tfutils.summary import summary_moving_average
@@ -26,7 +27,7 @@ class MultiGPUTrainer(QueueInputTrainer):
     @staticmethod
     def _average_grads(tower_grads):
         ret = []
-        with tf.name_scope('average_grad'):
+        with tf.name_scope('AvgGrad'):
             for grad_and_vars in zip(*tower_grads):
                 v = grad_and_vars[0][1]
                 try:
@@ -44,12 +45,12 @@ class MultiGPUTrainer(QueueInputTrainer):
         grad_list = []
         for idx, t in enumerate(self.config.tower):
             with tf.device('/gpu:{}'.format(t)), \
-                    tf.name_scope('tower{}'.format(idx)) as scope:
+                    TowerContext('tower{}'.format(idx)) as scope:
                 logger.info("Building graph for training tower {}...".format(idx))
                 model_inputs = self._get_model_inputs()    # each tower dequeue from input queue
                 self.dequed_inputs.append(model_inputs)

-                self.model.build_graph(model_inputs, True)
+                self.model.build_graph(model_inputs)
                 cost_var = self.model.get_cost() # build tower

                 # TODO gate_gradienst=0 seems to be faster?
@@ -92,7 +93,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
         # pretend to average the grads, in order to make async and
         # sync have consistent effective learning rate
         def scale(grads):
-            with tf.name_scope('async_scale_grad'):
+            with tf.name_scope('AsyncScaleGrad'):
                 return [(grad / len(self.config.tower) if grad is not None else None, var)
                         for grad, var in grads]
         grad_list = map(scale, grad_list)
...
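For reference, the gradient averaging that `_average_grads` performs under the renamed `AvgGrad` scope can be sketched standalone. This assumes every tower yields `(grad, var)` pairs for the same variables in the same order, and simply skips `None` gradients; it is a simplification, not the trainer's exact code:

    import tensorflow as tf

    def average_grads(tower_grads):
        """tower_grads: one [(grad, var), ...] list per tower."""
        ret = []
        with tf.name_scope('AvgGrad'):
            # zip(*tower_grads) regroups the pairs per variable across towers
            for grad_and_vars in zip(*tower_grads):
                var = grad_and_vars[0][1]
                grads = [g for g, _ in grad_and_vars if g is not None]
                if not grads:
                    continue    # variable had no gradient on any tower
                ret.append((tf.add_n(grads) / float(len(grads)), var))
        return ret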
@@ -10,19 +10,18 @@ from six.moves import zip
 from .base import Trainer
 from ..dataflow.common import RepeatedData
-from ..tfutils.summary import summary_moving_average
-from ..tfutils.modelutils import describe_model

+from ..models import TowerContext
 from ..utils import *
 from ..tfutils import *
-from ..tfutils.summary import add_moving_summary
+from ..tfutils.summary import summary_moving_average, add_moving_summary
+from ..tfutils.modelutils import describe_model
 from ..predict import OnlinePredictor, build_multi_tower_prediction_graph

 __all__ = ['SimpleTrainer', 'QueueInputTrainer']

 class PredictorFactory(object):
     """ Make predictors for a trainer"""
-    PREFIX = 'towerp'

     def __init__(self, sess, model, towers):
         """
@@ -42,7 +41,7 @@ class PredictorFactory(object):
         self._build_predict_tower()
         tower = self.towers[tower % len(self.towers)]
         raw_input_vars = get_vars_by_names(input_names)
-        output_names = ['{}{}/'.format(self.PREFIX, tower) + n for n in output_names]
+        output_names = ['towerp{}/'.format(tower) + n for n in output_names]
         output_vars = get_vars_by_names(output_names)
         return OnlinePredictor(self.sess, raw_input_vars, output_vars)
@@ -52,7 +51,7 @@ class PredictorFactory(object):
         with tf.name_scope(None), \
                 freeze_collection(SUMMARY_BACKUP_KEYS):
             build_multi_tower_prediction_graph(
-                self.model, self.towers, prefix=self.PREFIX)
+                self.model, self.towers)
         self.tower_built = True

 class SimpleTrainer(Trainer):
@@ -64,8 +63,9 @@ class SimpleTrainer(Trainer):
     def train(self):
         model = self.model
         self.input_vars = model.get_input_vars()
-        model.build_graph(self.input_vars, True)
-        cost_var = model.get_cost() # TODO assert scalar
+        with TowerContext(''):
+            model.build_graph(self.input_vars)
+            cost_var = model.get_cost() # TODO assert scalar
         add_moving_summary(cost_var)

         grads = self.config.optimizer.compute_gradients(cost_var)
@@ -180,8 +180,9 @@ class QueueInputTrainer(Trainer):
         #self.dequed_inputs = [tf.Variable(tf.random_normal([128,224,224,3],
             #dtype=tf.float32), trainable=False),
             #tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
-        self.model.build_graph(self.dequed_inputs, True)
-        cost_var = self.model.get_cost()
+        with TowerContext(''):
+            self.model.build_graph(self.dequed_inputs)
+            cost_var = self.model.get_cost()
         grads = self.config.optimizer.compute_gradients(
             cost_var, gate_gradients=0)  # GATE_NONE
         add_moving_summary(cost_var)
...
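Finally, `TowerContext` is deliberately non-reentrant: `__enter__` asserts the global slot is empty, so the single-tower trainers' anonymous `TowerContext('')` and any per-tower context can never be active at the same time. A tiny sketch of the guard in action:

    from tensorpack.models import TowerContext

    with TowerContext('tower0'):
        try:
            with TowerContext('tower1'):
                pass
        except AssertionError as e:
            print(e)    # -> Nesting TowerContext!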