Commit 61a1d3f2 authored by Yuxin Wu's avatar Yuxin Wu

Translate placeholder's tensor name to op name when contructing TensorSpec (fix #1118)

parent 2b4ec72e
......@@ -259,13 +259,13 @@ def finalize_configs(is_training):
else:
assert 'OMPI_COMM_WORLD_SIZE' not in os.environ
ngpu = get_num_gpu()
assert ngpu > 0, "Has to train with GPU!"
assert ngpu % 8 == 0 or 8 % ngpu == 0, "Can only train with 1,2,4 or >=8 GPUs, but found {} GPUs".format(ngpu)
else:
# autotune is too slow for inference
os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0'
ngpu = get_num_gpu()
assert ngpu > 0, "Has to run with GPU!"
if _C.TRAIN.NUM_GPUS is None:
_C.TRAIN.NUM_GPUS = ngpu
else:
......
......@@ -7,6 +7,7 @@ import tensorflow as tf
from ..utils.argtools import memoized_method
from ..utils.develop import deprecated
from ..tfutils.common import get_op_tensor_name
from ..compat import backport_tensor_spec, tfv1
TensorSpec = backport_tensor_spec()
......@@ -88,7 +89,7 @@ class ModelDescBase(object):
assert "Placeholder" in p.op.type, \
"inputs() have to return TensorSpec or placeholders! Found {} instead.".format(p)
assert p.graph == G, "Placeholders returned by inputs() should be created inside inputs()!"
return [TensorSpec(shape=p.shape, dtype=p.dtype, name=p.name) for p in inputs]
return [TensorSpec(shape=p.shape, dtype=p.dtype, name=get_op_tensor_name(p.name)[0]) for p in inputs]
@property
def input_names(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment