Commit 1139854d authored by Yuxin Wu's avatar Yuxin Wu

handle model with different parameter dtypes

parent 92a9315e
...@@ -218,10 +218,17 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder): ...@@ -218,10 +218,17 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
logger.warn("mode='hierarchical' require >= 8 GPUs. Fallback to mode='cpu'.") logger.warn("mode='hierarchical' require >= 8 GPUs. Fallback to mode='cpu'.")
self._mode = 'cpu' self._mode = 'cpu'
dtypes = set([x[0].dtype.base_dtype for x in grad_list[0]])
valid_for_nccl = all([k in [tf.float32, tf.float64] for k in dtypes])
if self._mode == 'nccl' and not valid_for_nccl:
logger.warn("Cannot use mode='nccl' because some gradients have unsupported types. Fallback to mode='cpu'")
self._mode = 'cpu'
if self._mode in ['nccl', 'hierarchical']: if self._mode in ['nccl', 'hierarchical']:
all_grads, all_vars = split_grad_list(grad_list) all_grads, all_vars = split_grad_list(grad_list)
if self._mode == 'nccl': if self._mode == 'nccl':
all_grads = allreduce_grads(all_grads, average=self._average) # #gpu x #param x 2 all_grads = allreduce_grads(all_grads, average=self._average) # #gpu x #param
else: else:
packer = GradientPacker(len(raw_devices)) packer = GradientPacker(len(raw_devices))
succ = packer.compute_strategy(all_grads[0]) succ = packer.compute_strategy(all_grads[0])
......
...@@ -23,7 +23,6 @@ def describe_trainable_vars(): ...@@ -23,7 +23,6 @@ def describe_trainable_vars():
total = 0 total = 0
total_bytes = 0 total_bytes = 0
data = [] data = []
devices = set()
for v in train_vars: for v in train_vars:
if v.name.startswith('tower'): if v.name.startswith('tower'):
continue continue
...@@ -31,16 +30,23 @@ def describe_trainable_vars(): ...@@ -31,16 +30,23 @@ def describe_trainable_vars():
ele = shape.num_elements() ele = shape.num_elements()
total += ele total += ele
total_bytes += ele * v.dtype.size total_bytes += ele * v.dtype.size
devices.add(v.device) data.append([v.name, shape.as_list(), ele, v.device, v.dtype.base_dtype.name])
data.append([v.name, shape.as_list(), ele, v.device]) headers = ['name', 'shape', 'dim', 'device', 'dtype']
dtypes = set([x[4] for x in data])
if len(dtypes) == 1:
for x in data:
del x[4]
del headers[4]
devices = set([x[3] for x in data])
if len(devices) == 1: if len(devices) == 1:
# don't log the device if all vars on the same device # don't log the device if all vars on the same device
for d in data: for x in data:
d.pop() del x[3]
table = tabulate(data, headers=['name', 'shape', 'dim']) del headers[3]
else:
table = tabulate(data, headers=['name', 'shape', 'dim', 'device']) table = tabulate(data, headers=headers)
size_mb = total_bytes / 1024.0**2 size_mb = total_bytes / 1024.0**2
summary_msg = colored( summary_msg = colored(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment