Commit 21ded816 authored by Yuxin Wu's avatar Yuxin Wu

Fix stagingarea when working with variables

parent bab16832
......@@ -87,7 +87,10 @@ class DataParallelBuilder(GraphBuilder):
is_training=True,
index=idx,
vs_name=tower_names[idx] if usevs else ''):
if len(str(device)) < 10: # a device function doesn't have good string description
logger.info("Building graph for training tower {} on device {}...".format(idx, device))
else:
logger.info("Building graph for training tower {} ...".format(idx))
# When use_vs is True, use LOCAL_VARIABLES,
# so these duplicated variables won't be saved by default.
......
......@@ -539,7 +539,15 @@ class StagingInput(FeedfreeInput):
for idx, device in enumerate(self._devices):
with tf.device(device):
inputs = self._input.get_input_tensors()
dtypes = [x.dtype for x in inputs]
# Putting variables to stagingarea will cause trouble
dtypes = []
for idx in range(len(inputs)):
dtype = inputs[idx].dtype
if dtype.base_dtype != dtype: # is reference type
inputs[idx] = tf.identity(inputs[idx])
dtypes.append(dtype.base_dtype)
stage = StagingArea(dtypes, shapes=None)
self._stage_ops.append(stage.put(inputs))
self._areas.append(stage)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment