Commit 4a88dfc3 authored by Yuxin Wu's avatar Yuxin Wu

always reuse when not training. otherwise reuse can be set back to False in TF<1.1 (fix #277)

parent 789f082f
...@@ -109,10 +109,9 @@ class InferenceRunnerBase(Callback): ...@@ -109,10 +109,9 @@ class InferenceRunnerBase(Callback):
in_tensors = self._find_input_tensors() in_tensors = self._find_input_tensors()
assert isinstance(in_tensors, list), in_tensors assert isinstance(in_tensors, list), in_tensors
with tf.variable_scope(tf.get_variable_scope(), reuse=True): def fn(_):
def fn(_): self.trainer.model.build_graph(in_tensors)
self.trainer.model.build_graph(in_tensors) PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
self._feed_tensors = self._find_feed_tensors() self._feed_tensors = self._find_feed_tensors()
self._hooks = [self._build_hook(inf) for inf in self.infs] self._hooks = [self._build_hook(inf) for inf in self.infs]
......
...@@ -107,8 +107,9 @@ class TowerContext(object): ...@@ -107,8 +107,9 @@ class TowerContext(object):
self._ctxs.append(tf.variable_scope(self._name)) self._ctxs.append(tf.variable_scope(self._name))
else: else:
# use existing variable scope # use existing variable scope
reuse = self.index > 0 or (not self.is_training)
self._ctxs.append(tf.variable_scope( self._ctxs.append(tf.variable_scope(
tf.get_variable_scope(), reuse=self.index > 0)) tf.get_variable_scope(), reuse=reuse))
self._ctxs.append(tf.name_scope(self._name)) self._ctxs.append(tf.name_scope(self._name))
self._ctxs.append(tf.device(self._device)) self._ctxs.append(tf.device(self._device))
for c in self._ctxs: for c in self._ctxs:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment