Commit 61a1d3f2 authored by Yuxin Wu's avatar Yuxin Wu

Translate placeholder's tensor name to op name when contructing TensorSpec (fix #1118)

parent 2b4ec72e
...@@ -259,13 +259,13 @@ def finalize_configs(is_training): ...@@ -259,13 +259,13 @@ def finalize_configs(is_training):
else: else:
assert 'OMPI_COMM_WORLD_SIZE' not in os.environ assert 'OMPI_COMM_WORLD_SIZE' not in os.environ
ngpu = get_num_gpu() ngpu = get_num_gpu()
assert ngpu > 0, "Has to train with GPU!"
assert ngpu % 8 == 0 or 8 % ngpu == 0, "Can only train with 1,2,4 or >=8 GPUs, but found {} GPUs".format(ngpu) assert ngpu % 8 == 0 or 8 % ngpu == 0, "Can only train with 1,2,4 or >=8 GPUs, but found {} GPUs".format(ngpu)
else: else:
# autotune is too slow for inference # autotune is too slow for inference
os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0' os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0'
ngpu = get_num_gpu() ngpu = get_num_gpu()
assert ngpu > 0, "Has to run with GPU!"
if _C.TRAIN.NUM_GPUS is None: if _C.TRAIN.NUM_GPUS is None:
_C.TRAIN.NUM_GPUS = ngpu _C.TRAIN.NUM_GPUS = ngpu
else: else:
......
...@@ -7,6 +7,7 @@ import tensorflow as tf ...@@ -7,6 +7,7 @@ import tensorflow as tf
from ..utils.argtools import memoized_method from ..utils.argtools import memoized_method
from ..utils.develop import deprecated from ..utils.develop import deprecated
from ..tfutils.common import get_op_tensor_name
from ..compat import backport_tensor_spec, tfv1 from ..compat import backport_tensor_spec, tfv1
TensorSpec = backport_tensor_spec() TensorSpec = backport_tensor_spec()
...@@ -88,7 +89,7 @@ class ModelDescBase(object): ...@@ -88,7 +89,7 @@ class ModelDescBase(object):
assert "Placeholder" in p.op.type, \ assert "Placeholder" in p.op.type, \
"inputs() have to return TensorSpec or placeholders! Found {} instead.".format(p) "inputs() have to return TensorSpec or placeholders! Found {} instead.".format(p)
assert p.graph == G, "Placeholders returned by inputs() should be created inside inputs()!" assert p.graph == G, "Placeholders returned by inputs() should be created inside inputs()!"
return [TensorSpec(shape=p.shape, dtype=p.dtype, name=p.name) for p in inputs] return [TensorSpec(shape=p.shape, dtype=p.dtype, name=get_op_tensor_name(p.name)[0]) for p in inputs]
@property @property
def input_names(self): def input_names(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment