Commit 7a0e8747 authored by Yuxin Wu

build_graph with ctx

parent aabab2cc
@@ -12,61 +12,6 @@ from ..utils import *
__all__ = ['Callbacks']
# --- Test-Callback related stuff doesn't seem very useful.
@contextmanager
def create_test_graph(trainer):
model = trainer.model
with tf.Graph().as_default() as Gtest:
# create a global step var in test graph
global_step_var = tf.Variable(
0, trainable=False, name=GLOBAL_STEP_OP_NAME)
input_vars = model.get_input_vars()
model.build_graph(input_vars, False)
cost = model.get_cost()
yield Gtest
@contextmanager
def create_test_session(trainer):
""" create a test-time session from trainer"""
with create_test_graph(trainer):
with tf.Session() as sess:
yield sess
class TestCallbackContext(object):
"""
A class holding the context needed for running TestCallback
"""
def __init__(self):
self.sess = None
@contextmanager
def create_context(self, trainer):
if self.sess is None:
with create_test_session(trainer) as sess:
self.sess = sess
self.graph = sess.graph
# no tower in test graph; just keep it as it is
self.saver = tf.train.Saver()
with self.graph.as_default(), self.sess.as_default():
yield
# TODO also do this for after_train?
def restore_checkpoint(self):
ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
if ckpt is None:
raise RuntimeError(
"Cannot find a checkpoint state. Do you forget to use ModelSaver before all TestCallback?")
logger.info(
"Restore checkpoint from {}".format(ckpt.model_checkpoint_path))
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
@contextmanager
def test_context(self):
with self.graph.as_default(), self.sess.as_default():
yield
# ---
class CallbackTimeLogger(object):
def __init__(self):
self.times = []
......
@@ -7,6 +7,7 @@ import tensorflow as tf
from copy import copy
import re
from .model_desc import get_current_tower_context
from ..utils import logger, EXTRA_SAVE_VARS_KEY
from ._common import layer_register
@@ -54,6 +55,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
# XXX a hack to handle training tower & prediction tower together....
emaname = 'EMA'
#ctx = get_current_model_context()
if not batch_mean.name.startswith('towerp'):
# training tower
with tf.name_scope(None): # https://github.com/tensorflow/tensorflow/issues/2740
......
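The 'towerp' name check above is exactly the condition the new TowerContext uses for its is_training default, so the prefix hack and the context agree on which towers are training towers. A minimal sketch of that equivalence (the helper name is made up for illustration):

def is_training_tower(tower_name):
    # TowerContext(tower_name) defaults is_training to this same condition (see model_desc below)
    return not tower_name.startswith('towerp')

assert is_training_tower('tower0')        # training tower
assert not is_training_tower('towerp0')   # prediction tower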
@@ -6,14 +6,54 @@
from abc import ABCMeta, abstractmethod
import tensorflow as tf
from collections import namedtuple
import inspect
from ..utils import logger, INPUT_VARS_KEY
from ..tfutils import *
__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']
__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph',
'get_current_tower_context', 'TowerContext']
InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])
_CurrentTowerContext = None
class TowerContext(object):
def __init__(self, tower_name, is_training=None):
""" tower_name: 'tower0', 'towerp0', or '' """
self._name = tower_name
if is_training is None:
is_training = not self._name.startswith('towerp')
self._is_training = is_training
@property
def is_main_tower(self):
return self._name == '' or self._name == 'tower0'
@property
def is_training(self):
return self._is_training
def __enter__(self):
global _CurrentTowerContext
assert _CurrentTowerContext is None, \
"Nesting TowerContext!"
_CurrentTowerContext = self
if len(self._name):
self._scope = tf.name_scope(self._name)
return self._scope.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
global _CurrentTowerContext
_CurrentTowerContext = None
if len(self._name):
self._scope.__exit__(exc_type, exc_val, exc_tb)
return False
def get_current_tower_context():
global _CurrentTowerContext
return _CurrentTowerContext
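A brief usage sketch of the context above (the tensorpack import path is assumed from the `from ..models import TowerContext` lines later in this diff; the rest follows the class definition):

import tensorflow as tf
from tensorpack.models import TowerContext, get_current_tower_context

# Training tower: the name does not start with 'towerp', so is_training defaults to True.
with TowerContext('tower0'):
    ctx = get_current_tower_context()
    assert ctx.is_training and ctx.is_main_tower
    x = tf.placeholder(tf.float32, [None, 10])   # created under the 'tower0/' name scope

# Prediction graph: is_training can also be forced explicitly, as OfflinePredictor does below.
with TowerContext('', is_training=False):
    assert not get_current_tower_context().is_training

# Contexts cannot be nested: __enter__ asserts that no other context is active.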
class ModelDesc(object):
""" Base class for a model description """
__metaclass__ = ABCMeta
@@ -49,7 +89,7 @@ class ModelDesc(object):
def _get_input_vars(self):
""":returns: a list of InputVar """
def build_graph(self, model_inputs, is_training):
def build_graph(self, model_inputs):
"""
Setup the whole graph.
@@ -57,10 +97,15 @@
:param is_training: a boolean
:returns: the cost to minimize. a scalar variable
"""
self._build_graph(model_inputs, is_training)
if len(inspect.getargspec(self._build_graph).args) == 3:
logger.warn("_build_graph(self, input_vars, is_training) is deprecated! \
Use _build_graph(self, input_vars) and get_current_tower_context().is_training instead.")
self._build_graph(model_inputs, get_current_tower_context().is_training)
else:
self._build_graph(model_inputs)
@abstractmethod
def _build_graph(self, inputs, is_training):
def _build_graph(self, inputs):
pass
def get_cost(self):
......
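For reference, a hedged sketch of a model written against the new interface (the class, inputs, and dropout detail are illustrative; only the _build_graph signature and the context query come from this diff, and cost handling is omitted):

class MyModel(ModelDesc):
    def _get_input_vars(self):
        # InputVar(type, shape, name), as defined above
        return [InputVar(tf.float32, [None, 28, 28], 'input'),
                InputVar(tf.int32, [None], 'label')]

    def _build_graph(self, inputs):
        # new-style signature: no is_training argument; ask the tower context instead
        image, label = inputs
        is_training = get_current_tower_context().is_training
        keep_prob = 0.5 if is_training else 1.0
        # ... build layers using keep_prob / is_training, then define the cost ...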
@@ -7,6 +7,7 @@ from abc import abstractmethod, ABCMeta, abstractproperty
import tensorflow as tf
import six
from ..models import TowerContext
from ..utils import logger
from ..tfutils import get_vars_by_names
@@ -88,7 +89,8 @@ class OfflinePredictor(OnlinePredictor):
self.graph = tf.Graph()
with self.graph.as_default():
input_vars = config.model.get_input_vars()
config.model._build_graph(input_vars, False)
with TowerContext('', False):
config.model.build_graph(input_vars)
input_vars = get_vars_by_names(config.input_var_names)
output_vars = get_vars_by_names(config.output_var_names)
@@ -99,7 +101,7 @@ class OfflinePredictor(OnlinePredictor):
sess, input_vars, output_vars, config.return_input)
def build_multi_tower_prediction_graph(model, towers, prefix='towerp'):
def build_multi_tower_prediction_graph(model, towers):
"""
:param towers: a list of relative GPU ids.
"""
@@ -107,26 +109,24 @@ def build_multi_tower_prediction_graph(model, towers, prefix='towerp'):
for k in towers:
logger.info(
"Building graph for predictor tower {}...".format(k))
with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'),\
tf.name_scope('{}{}'.format(prefix, k)):
model._build_graph(input_vars, False)
with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
TowerContext('towerp{}'.format(k)):
model.build_graph(input_vars)
tf.get_variable_scope().reuse_variables()
class MultiTowerOfflinePredictor(OnlinePredictor):
PREFIX = 'towerp'
def __init__(self, config, towers):
self.graph = tf.Graph()
self.predictors = []
with self.graph.as_default():
# TODO backup summary keys?
build_multi_tower_prediction_graph(config.model, towers, self.PREFIX)
build_multi_tower_prediction_graph(config.model, towers)
self.sess = tf.Session(config=config.session_config)
config.session_init.init(self.sess)
input_vars = get_vars_by_names(config.input_var_names)
# use the first tower for compatible PredictorBase interface
for k in towers:
output_vars = get_vars_by_names(
['{}{}/'.format(self.PREFIX, k) + n \
@@ -135,6 +135,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
self.sess, input_vars, output_vars, config.return_input))
def _do_call(self, dp):
# use the first tower for compatible PredictorBase interface
return self.predictors[0]._do_call(dp)
def get_predictors(self, n):
......
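The per-tower output names used above are plain string prefixes on tensor names; a small illustration (the tensor name is made up):

towers = [0, 1]
output_names = ['prob:0']
for k in towers:
    # each prediction tower is built under the name scope 'towerp{k}',
    # so its copy of 'prob:0' is fetched as 'towerp{k}/prob:0'
    prefixed = ['towerp{}/'.format(k) + n for n in output_names]
    print(prefixed)   # ['towerp0/prob:0'], then ['towerp1/prob:0']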
@@ -7,6 +7,7 @@ import tensorflow as tf
import itertools, re
from six.moves import zip, range
from ..models import TowerContext
from ..utils import *
from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average
@@ -26,7 +27,7 @@ class MultiGPUTrainer(QueueInputTrainer):
@staticmethod
def _average_grads(tower_grads):
ret = []
with tf.name_scope('average_grad'):
with tf.name_scope('AvgGrad'):
for grad_and_vars in zip(*tower_grads):
v = grad_and_vars[0][1]
try:
@@ -44,12 +45,12 @@ class MultiGPUTrainer(QueueInputTrainer):
grad_list = []
for idx, t in enumerate(self.config.tower):
with tf.device('/gpu:{}'.format(t)), \
tf.name_scope('tower{}'.format(idx)) as scope:
TowerContext('tower{}'.format(idx)) as scope:
logger.info("Building graph for training tower {}...".format(idx))
model_inputs = self._get_model_inputs() # each tower dequeues from the input queue
self.dequed_inputs.append(model_inputs)
self.model.build_graph(model_inputs, True)
self.model.build_graph(model_inputs)
cost_var = self.model.get_cost() # build tower
# TODO gate_gradients=0 seems to be faster?
@@ -92,7 +93,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
# pretend to average the grads, in order to make async and
# sync have consistent effective learning rate
def scale(grads):
with tf.name_scope('async_scale_grad'):
with tf.name_scope('AsyncScaleGrad'):
return [(grad / len(self.config.tower) if grad is not None else None, var)
for grad, var in grads]
grad_list = map(scale, grad_list)
......
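A numeric sketch of the async scaling above (the values are made up): with N towers applying gradients independently, roughly N updates happen per synchronous step, so dividing each gradient by N keeps the effective learning rate comparable to the synchronous trainer.

num_towers = 4
grads = [(0.8, 'W'), (None, 'b')]   # (gradient, variable) pairs, illustrative
scaled = [(g / num_towers if g is not None else None, v) for g, v in grads]
print(scaled)                        # [(0.2, 'W'), (None, 'b')]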
@@ -10,19 +10,18 @@ from six.moves import zip
from .base import Trainer
from ..dataflow.common import RepeatedData
from ..tfutils.summary import summary_moving_average
from ..tfutils.modelutils import describe_model
from ..models import TowerContext
from ..utils import *
from ..tfutils import *
from ..tfutils.summary import add_moving_summary
from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..tfutils.modelutils import describe_model
from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
__all__ = ['SimpleTrainer', 'QueueInputTrainer']
class PredictorFactory(object):
""" Make predictors for a trainer"""
PREFIX = 'towerp'
def __init__(self, sess, model, towers):
"""
@@ -42,7 +41,7 @@ class PredictorFactory(object):
self._build_predict_tower()
tower = self.towers[tower % len(self.towers)]
raw_input_vars = get_vars_by_names(input_names)
output_names = ['{}{}/'.format(self.PREFIX, tower) + n for n in output_names]
output_names = ['towerp{}/'.format(tower) + n for n in output_names]
output_vars = get_vars_by_names(output_names)
return OnlinePredictor(self.sess, raw_input_vars, output_vars)
@@ -52,7 +51,7 @@ class PredictorFactory(object):
with tf.name_scope(None), \
freeze_collection(SUMMARY_BACKUP_KEYS):
build_multi_tower_prediction_graph(
self.model, self.towers, prefix=self.PREFIX)
self.model, self.towers)
self.tower_built = True
class SimpleTrainer(Trainer):
@@ -64,7 +63,8 @@ class SimpleTrainer(Trainer):
def train(self):
model = self.model
self.input_vars = model.get_input_vars()
model.build_graph(self.input_vars, True)
with TowerContext(''):
model.build_graph(self.input_vars)
cost_var = model.get_cost() # TODO assert scalar
add_moving_summary(cost_var)
@@ -180,7 +180,8 @@ class QueueInputTrainer(Trainer):
#self.dequed_inputs = [tf.Variable(tf.random_normal([128,224,224,3],
#dtype=tf.float32), trainable=False),
#tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
self.model.build_graph(self.dequed_inputs, True)
with TowerContext(''):
self.model.build_graph(self.dequed_inputs)
cost_var = self.model.get_cost()
grads = self.config.optimizer.compute_gradients(
cost_var, gate_gradients=0) # GATE_NONE
......