Commit 7b4980c9 authored by Yuxin Wu

make XLA work on more models

parent 3d28966c
@@ -222,19 +222,26 @@ def add_moving_summary(*args, **kwargs):
# allow ctx to be None
if ctx is not None and not ctx.is_main_training_tower:
return []
graph = tf.get_default_graph()
try:
control_flow_ctx = graph._get_control_flow_context()
# XLA does not support summaries anyway.
# However, this function would create unnecessary dependency edges,
# which make the tower function harder to compile under XLA, so skip it.
if control_flow_ctx is not None and control_flow_ctx.IsXLAContext():
return []
except Exception:
pass
if tf.get_variable_scope().reuse is True:
logger.warn("add_moving_summary() called under reuse=True scope, ignored.")
return []
if len(args) == 1 and isinstance(args[0], (list, tuple)):
logger.warn("add_moving_summary() takes positional args instead of an iterable of tensors!")
args = args[0]
for x in args:
assert isinstance(x, (tf.Tensor, tf.Variable)), x
assert x.get_shape().ndims == 0, \
"add_moving_summary() only accepts scalar tensor! Got one with {}".format(x.get_shape())
# TODO variable not saved under distributed
ema_ops = []
for c in args:
@@ -245,7 +252,8 @@ def add_moving_summary(*args, **kwargs):
# assign_moving_average creates variables with op names, therefore clear ns first.
with _enter_vs_reuse_ns('EMA') as vs:
ema_var = tf.get_variable(name, shape=c.shape, dtype=c.dtype,
initializer=tf.constant_initializer(), trainable=False)
initializer=tf.constant_initializer(),
trainable=False)
ns = vs.original_name_scope
with tf.name_scope(ns): # reuse VS&NS so that EMA_1 won't appear
ema_op = moving_averages.assign_moving_average(
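The XLA-context check added above relies on a private TF 1.x graph API. As a standalone sketch only (the helper name in_xla_context is illustrative and not part of this commit), the same pattern can be written as:

import tensorflow as tf

def in_xla_context(graph=None):
    # Best-effort check: are ops currently being built inside an XLA
    # compilation context? Relies on the private TF 1.x API
    # graph._get_control_flow_context(); any failure means "assume no".
    graph = graph or tf.get_default_graph()
    try:
        ctx = graph._get_control_flow_context()
        return ctx is not None and ctx.IsXLAContext()
    except Exception:
        return False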
@@ -177,7 +177,16 @@ class SingleCostTrainer(TowerTrainer):
"""See `tf.gradients`. """
XLA_COMPILE = False
""" Use :func:`xla.compile` to compile the tower function. """
""" Use :func:`xla.compile` to compile the tower function.
Note that XLA has very strong requirements on the tower function, e.g.:
1. limited op support
2. inferrable shape
3. no summary support
and many tower functions cannot be compiled by XLA.
Don't use it if you don't understand it.
"""
@call_only_once
def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
@@ -230,9 +239,13 @@ class SingleCostTrainer(TowerTrainer):
def get_grad_fn():
ctx = get_current_tower_context()
cost = get_cost_fn(*input.get_input_tensors())
inputs = input.get_input_tensors()
def compute_grad_from_inputs(*inputs):
cost = get_cost_fn(*inputs)
assert isinstance(cost, tf.Tensor), cost
assert cost.shape.ndims == 0, "Cost must be a scalar, but found {}!".format(cost)
if not ctx.is_training:
return None # this is the tower function, could be called for inference
@@ -250,25 +263,22 @@ class SingleCostTrainer(TowerTrainer):
return grads
if not self.XLA_COMPILE:
return get_grad_fn
return compute_grad_from_inputs(*inputs)
else:
from tensorflow.contrib.compiler import xla
def xla_get_grad_fn():
def xla_func():
grads = get_grad_fn()
grads = compute_grad_from_inputs(*inputs)
# unpack, because the return value
# of an XLA-compiled function cannot have a nested structure
grads = [x[0] for x in grads]
return grads
grads_no_vars = xla.compile(xla_func)
# repack again
ctx = get_current_tower_context()
if ctx.has_own_variables:
varlist = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
else:
varlist = tf.trainable_variables()
return list(zip(grads_no_vars, varlist))
return xla_get_grad_fn
return get_grad_fn
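For reference, a minimal self-contained sketch of the compile-then-repack pattern used above, assuming TF 1.x with tensorflow.contrib available. The model and names (build_cost, xla_func) are illustrative, not part of this commit; creating variables inside the compiled function mirrors what the tower function does here.

import tensorflow as tf
from tensorflow.contrib.compiler import xla

def build_cost(x):
    # toy scalar cost with a fixed, fully-known shape (XLA needs inferrable shapes)
    w = tf.get_variable('w', shape=[4, 1])
    return tf.reduce_mean(tf.square(tf.matmul(x, w)))

x = tf.placeholder(tf.float32, [8, 4])
opt = tf.train.GradientDescentOptimizer(0.1)

def xla_func():
    cost = build_cost(x)
    grads_and_vars = opt.compute_gradients(cost)
    # unpack: the return value of a compiled function cannot have nested structure
    return [g for g, _ in grads_and_vars]

grads = xla.compile(xla_func)                                 # flat list of gradient tensors
grads_and_vars = list(zip(grads, tf.trainable_variables()))   # repack for apply_gradients
train_op = opt.apply_gradients(grads_and_vars)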