Commit 3d28966c authored by Yuxin Wu

support XLA

parent 35beb43c
@@ -258,6 +258,11 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
         if self._mode in ['nccl', 'hierarchical']:
             all_grads, all_vars = split_grad_list(grad_list)
+            # use allreduce from tf-benchmarks
+            # from .batch_allreduce import AllReduceSpecAlgorithm
+            # algo = AllReduceSpecAlgorithm('nccl', list(range(8)), 0, 10)
+            # all_grads, warmup_ops = algo.batch_all_reduce(all_grads, 1, True, False)
+            # print("WARMUP OPS", warmup_ops)
             if self._mode == 'nccl':
                 all_grads = allreduce_grads(all_grads, average=self._average)  # #gpu x #param
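For context on the "#gpu x #param" layout in this hunk: all_grads[i][j] is the gradient of parameter j computed on GPU i, and allreduce_grads replaces each entry with the (averaged) sum across GPUs. Below is a minimal, framework-free sketch of that reduction; naive_allreduce is an illustrative name, not tensorpack or NCCL API.

# Illustrative sketch only (hypothetical helper, not tensorpack/NCCL code):
# shows the "#gpu x #param" data layout that allreduce_grads() consumes.
def naive_allreduce(all_grads, average=True):
    # all_grads[i][j]: gradient of parameter j computed on GPU i
    num_gpus = len(all_grads)
    reduced = []
    for per_param in zip(*all_grads):    # collect gradient j from every GPU
        total = sum(per_param)           # NCCL performs this sum on-device
        reduced.append(total / num_gpus if average else total)
    # every GPU ends up with the same reduced gradients, preserving the shape
    return [list(reduced) for _ in range(num_gpus)]

print(naive_allreduce([[1.0, 10.0], [3.0, 30.0]]))  # [[2.0, 20.0], [2.0, 20.0]]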
@@ -176,6 +176,9 @@ class SingleCostTrainer(TowerTrainer):
     AGGREGATION_METHOD = tf.AggregationMethod.DEFAULT
     """See `tf.gradients`. """

+    XLA_COMPILE = False
+    """ Use :func:`xla.compile` to compile the tower function. """
+
     @call_only_once
     def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
         """
@@ -246,4 +249,26 @@ class SingleCostTrainer(TowerTrainer):
             grads = FilterNoneGrad().process(grads)
             return grads
-        return get_grad_fn
+
+        if not self.XLA_COMPILE:
+            return get_grad_fn
+        else:
+            from tensorflow.contrib.compiler import xla
+
+            def xla_get_grad_fn():
+                def xla_func():
+                    grads = get_grad_fn()
+                    # unpack, because the return value
+                    # of xla function cannot have nested structure
+                    grads = [x[0] for x in grads]
+                    return grads
+                grads_no_vars = xla.compile(xla_func)
+                # repack again
+                ctx = get_current_tower_context()
+                if ctx.has_own_variables:
+                    varlist = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
+                else:
+                    varlist = tf.trainable_variables()
+                return list(zip(grads_no_vars, varlist))
+            return xla_get_grad_fn
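A hedged usage sketch of the new flag (not part of this commit): since XLA_COMPILE is a class attribute of SingleCostTrainer, any single-cost trainer can opt in before the graph is built. The trainer class and the commented training call below are only examples of how the flag might be flipped.

# Hypothetical usage sketch: opt a single-cost trainer into XLA compilation.
from tensorpack import SimpleTrainer

trainer = SimpleTrainer()
trainer.XLA_COMPILE = True   # tower function will be wrapped with xla.compile()
# ... then train as usual, e.g. launch_train_with_config(config, trainer)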
@@ -152,6 +152,15 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
     List of GPU ids.
     """

+    BROADCAST_EVERY_EPOCH = True
+    """
+    Whether to broadcast the variables every epoch.
+    Theoretically this is a no-op (because the variables
+    are supposed to be in-sync).
+    But this cheap operation may help prevent
+    certain numerical issues in practice.
+    """
+
     @map_arg(gpus=_int_to_range)
     def __init__(self, gpus, average=True, mode=None, use_nccl=None):
         """
@@ -186,7 +195,9 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
         cb = RunOp(
             post_init_op,
-            run_before=True, run_as_trigger=True, verbose=True)
+            run_before=True,
+            run_as_trigger=self.BROADCAST_EVERY_EPOCH,
+            verbose=True)
         return [cb]
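A similarly hedged sketch for the new BROADCAST_EVERY_EPOCH switch: because run_before=True is kept unconditionally, the broadcast still runs once before training even when the per-epoch trigger is turned off. The GPU list and mode argument below are illustrative.

# Hypothetical usage sketch: keep the initial broadcast, skip the per-epoch one.
from tensorpack import SyncMultiGPUTrainerReplicated

trainer = SyncMultiGPUTrainerReplicated([0, 1, 2, 3], mode='nccl')
trainer.BROADCAST_EVERY_EPOCH = False   # RunOp still fires once before training (run_before=True)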