Commit 93d4efb5 authored by ksellesk, committed by Yuxin Wu

Fix valid_for_nccl: nccl supports fp16 now (#926)

* Fix valid_for_nccl: nccl supports fp16 now

* Check TF version first for valid_for_nccl
parent 9a777e98
@@ -13,6 +13,7 @@ from contextlib import contextmanager
 from ..utils import logger
 from ..tfutils.tower import TrainTowerContext
 from ..tfutils.gradproc import ScaleGradient
+from ..tfutils.common import get_tf_version_tuple
 from .utils import (
     LeastLoadedDeviceSetter, override_to_local_variable,
@@ -228,7 +229,10 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
             self._mode = 'cpu'

         dtypes = set([x[0].dtype.base_dtype for x in grad_list[0]])
-        valid_for_nccl = all([k in [tf.float32, tf.float64] for k in dtypes])
+        dtypes_nccl_supported = [tf.float32, tf.float64]
+        if get_tf_version_tuple() >= (1, 8):
+            dtypes_nccl_supported.append(tf.float16)
+        valid_for_nccl = all([k in dtypes_nccl_supported for k in dtypes])
         if self._mode == 'nccl' and not valid_for_nccl:
             logger.warn("Cannot use mode='nccl' because some gradients have unsupported types. Fallback to mode='cpu'")
             self._mode = 'cpu'
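
For reference, a minimal standalone sketch of the version-gated dtype check this commit introduces. It parses tf.__version__ directly instead of using tensorpack's get_tf_version_tuple(), and the name gradients_valid_for_nccl and its grad_and_vars argument are hypothetical stand-ins for one entry of grad_list; this is an illustration of the check, not the library code itself.

import tensorflow as tf


def gradients_valid_for_nccl(grad_and_vars):
    """Return True if every gradient dtype in (grad, var) pairs can be reduced with NCCL."""
    # Take the "major.minor" part of the TF version string, e.g. "1.8.0" -> (1, 8).
    version = tuple(map(int, tf.__version__.split('.')[:2]))
    dtypes_nccl_supported = [tf.float32, tf.float64]
    if version >= (1, 8):
        # Per this commit, fp16 is accepted for NCCL reduction from TF 1.8 on.
        dtypes_nccl_supported.append(tf.float16)
    dtypes = set(g.dtype.base_dtype for g, _ in grad_and_vars)
    return all(k in dtypes_nccl_supported for k in dtypes)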