Commit 3ef33a34 authored by Yuxin Wu's avatar Yuxin Wu

Handle variables of unknown static shape (#738)

parent 15c0e160
...@@ -79,6 +79,10 @@ class LeastLoadedDeviceSetter(object): ...@@ -79,6 +79,10 @@ class LeastLoadedDeviceSetter(object):
self.ps_sizes), key=operator.itemgetter(1)) self.ps_sizes), key=operator.itemgetter(1))
device_name = self.ps_devices[device_index] device_name = self.ps_devices[device_index]
var_size = op.outputs[0].get_shape().num_elements() var_size = op.outputs[0].get_shape().num_elements()
if var_size is None:
logger.warn("[LeastLoadedDeviceSetter] Shape of variable {} is not fully defined!".format(op.name))
var_size = 0
self.ps_sizes[device_index] += var_size self.ps_sizes[device_index] += var_size
return sanitize_name(device_name) return sanitize_name(device_name)
......
...@@ -28,9 +28,17 @@ def describe_trainable_vars(): ...@@ -28,9 +28,17 @@ def describe_trainable_vars():
continue continue
shape = v.get_shape() shape = v.get_shape()
ele = shape.num_elements() ele = shape.num_elements()
if ele is None:
logger.warn("Shape of variable {} is not fully defined but {}.".format(v.name, shape))
ele = 0
try:
shape = shape.as_list()
except ValueError:
shape = '<unknown>'
total += ele total += ele
total_bytes += ele * v.dtype.size total_bytes += ele * v.dtype.size
data.append([v.name, shape.as_list(), ele, v.device, v.dtype.base_dtype.name]) data.append([v.name, shape, ele, v.device, v.dtype.base_dtype.name])
headers = ['name', 'shape', 'dim', 'device', 'dtype'] headers = ['name', 'shape', 'dim', 'device', 'dtype']
dtypes = set([x[4] for x in data]) dtypes = set([x[4] for x in data])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment