Commit 93d4efb5 authored by ksellesk's avatar ksellesk Committed by Yuxin Wu

Fix valid_for_nccl: nccl supports fp16 now (#926)

* Fix valid_for_nccl: nccl supports fp16 now

* Check TF version first for valid_for_nccl
parent 9a777e98
@@ -13,6 +13,7 @@ from contextlib import contextmanager
 from ..utils import logger
 from ..tfutils.tower import TrainTowerContext
 from ..tfutils.gradproc import ScaleGradient
+from ..tfutils.common import get_tf_version_tuple
 from .utils import (
     LeastLoadedDeviceSetter, override_to_local_variable,
@@ -228,7 +229,10 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
             self._mode = 'cpu'

         dtypes = set([x[0].dtype.base_dtype for x in grad_list[0]])
-        valid_for_nccl = all([k in [tf.float32, tf.float64] for k in dtypes])
+        dtypes_nccl_supported = [tf.float32, tf.float64]
+        if get_tf_version_tuple() >= (1, 8):
+            dtypes_nccl_supported.append(tf.float16)
+        valid_for_nccl = all([k in dtypes_nccl_supported for k in dtypes])
         if self._mode == 'nccl' and not valid_for_nccl:
             logger.warn("Cannot use mode='nccl' because some gradients have unsupported types. Fallback to mode='cpu'")
             self._mode = 'cpu'
...
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment