Commit da143e0f authored by Yuxin Wu's avatar Yuxin Wu

Load nccl so (fix #913)

parent d1cc5a4a
...@@ -227,6 +227,12 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5, ...@@ -227,6 +227,12 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
"Cross-GPU BatchNorm is only supported in TF>=1.10 ." \ "Cross-GPU BatchNorm is only supported in TF>=1.10 ." \
"Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360" "Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"
try:
from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so
except Exception:
pass
else:
_validate_and_load_nccl_so()
from tensorflow.contrib.nccl.ops import gen_nccl_ops from tensorflow.contrib.nccl.ops import gen_nccl_ops
shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name) shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
batch_mean = gen_nccl_ops.nccl_all_reduce( batch_mean = gen_nccl_ops.nccl_all_reduce(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment