Commit e66857ba authored by Yuxin Wu's avatar Yuxin Wu

use reuse context instead of reuse_variables()

parent e1278514
......@@ -182,10 +182,9 @@ class FeedfreeInferenceRunner(Callback):
def _setup_graph(self):
self._find_input_tensors() # tensors
tf.get_variable_scope().reuse_variables()
# overwrite the FeedfreeInferenceRunner scope
with tf.name_scope(None), \
# overwrite the FeedfreeInferenceRunner name scope
with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
tf.name_scope(None), \
freeze_collection(SUMMARY_BACKUP_KEYS):
def fn(_):
self.trainer.model.build_graph(self._input_tensors)
......
......@@ -152,11 +152,12 @@ def build_prediction_graph(build_tower_fn, towers=[0], prefix=''):
prefix: an extra prefix in tower name. The final tower prefix will be
determined by :meth:`TowerContext.get_predict_tower_name`.
"""
for k in towers:
for idx, k in enumerate(towers):
logger.info(
"Building prediction graph for towerid={} with prefix='{}' ...".format(k, prefix))
towername = TowerContext.get_predict_tower_name(prefix, k)
with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
TowerContext(towername, is_training=False):
TowerContext(towername, is_training=False), \
tf.variable_scope(tf.get_variable_scope(),
reuse=True if idx > 0 else None):
build_tower_fn(k)
tf.get_variable_scope().reuse_variables()
......@@ -71,16 +71,17 @@ class DataParallelOfflinePredictor(OnlinePredictor):
sess = tf.Session(config=config.session_config)
input_var_names = []
output_vars = []
for k in towers:
for idx, k in enumerate(towers):
towername = PREDICT_TOWER + str(k)
input_vars = config.model.build_placeholders(
prefix=towername + '-')
logger.info(
"Building graph for predictor tower {}...".format(k))
with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
TowerContext(towername, is_training=False):
TowerContext(towername, is_training=False), \
tf.variable_scope(tf.get_variable_scope(),
reuse=True if idx > 0 else None):
config.model.build_graph(input_vars)
tf.get_variable_scope().reuse_variables()
input_var_names.extend([k.name for k in input_vars])
output_vars.extend(get_tensors_by_names(
[towername + '/' + n
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment