Commit e66857ba authored by Yuxin Wu

use reuse context instead of reuse_variables()

parent e1278514
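The change replaces `tf.get_variable_scope().reuse_variables()`, which flips the reuse flag on the current variable scope for the remainder of graph construction, with a `tf.variable_scope(..., reuse=True)` context, which confines reuse to the `with` block and reverts on exit. A minimal TF 1.x sketch of the difference (the variable name `w` is illustrative, not from this commit):

    import tensorflow as tf

    def tower():
        # get_variable creates 'w', or returns the existing one under reuse
        return tf.get_variable('w', shape=[3])

    w0 = tower()  # first call creates 'w'

    # Old idiom: permanently flips the reuse flag on the current scope:
    #   tf.get_variable_scope().reuse_variables()

    # New idiom: reuse applies only inside the `with` block:
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        w1 = tower()  # returns the existing 'w'

    assert w0 is w1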
@@ -182,10 +182,9 @@ class FeedfreeInferenceRunner(Callback):
     def _setup_graph(self):
         self._find_input_tensors()  # tensors
-        tf.get_variable_scope().reuse_variables()
-        # overwrite the FeedfreeInferenceRunner scope
-        with tf.name_scope(None), \
+        # overwrite the FeedfreeInferenceRunner name scope
+        with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
+                tf.name_scope(None), \
                 freeze_collection(SUMMARY_BACKUP_KEYS):
             def fn(_):
                 self.trainer.model.build_graph(self._input_tensors)
...
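In this hunk the reuse context wraps the inference-graph rebuild, while `tf.name_scope(None)` clears the name scope the callback would otherwise impose, so the rebuilt ops keep top-level names. A small sketch of the name-scope reset (TF 1.x; the scope and op names are illustrative):

    import tensorflow as tf

    with tf.name_scope('FeedfreeInferenceRunner'):
        a = tf.constant(1, name='a')      # op name: 'FeedfreeInferenceRunner/a'
        with tf.name_scope(None):         # None resets to the top-level scope
            b = tf.constant(2, name='b')  # op name: 'b'

    print(a.op.name, b.op.name)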
@@ -152,11 +152,12 @@ def build_prediction_graph(build_tower_fn, towers=[0], prefix=''):
         prefix: an extra prefix in tower name. The final tower prefix will be
             determined by :meth:`TowerContext.get_predict_tower_name`.
     """
-    for k in towers:
+    for idx, k in enumerate(towers):
         logger.info(
             "Building prediction graph for towerid={} with prefix='{}' ...".format(k, prefix))
         towername = TowerContext.get_predict_tower_name(prefix, k)
         with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
-                TowerContext(towername, is_training=False):
+                TowerContext(towername, is_training=False), \
+                tf.variable_scope(tf.get_variable_scope(),
+                                  reuse=True if idx > 0 else None):
             build_tower_fn(k)
-            tf.get_variable_scope().reuse_variables()
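`reuse=True if idx > 0 else None` lets the first tower create the variables and every later tower share them, instead of calling `reuse_variables()` after each tower body. A self-contained sketch of the pattern (`make_tower` is a hypothetical stand-in for `build_tower_fn`):

    import tensorflow as tf

    def make_tower(k):
        # hypothetical tower body: one placeholder input, one shared weight
        x = tf.placeholder(tf.float32, [None, 4], name='input')
        w = tf.get_variable('w', shape=[4, 2])
        return tf.matmul(x, w, name='output')

    for idx, k in enumerate([0, 1]):
        with tf.name_scope('towerp{}'.format(k)), \
                tf.variable_scope(tf.get_variable_scope(),
                                  reuse=True if idx > 0 else None):
            make_tower(k)

    # Only one copy of 'w' exists; both towers read it:
    print([v.name for v in tf.global_variables()])  # ['w:0']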
@@ -71,16 +71,17 @@ class DataParallelOfflinePredictor(OnlinePredictor):
         sess = tf.Session(config=config.session_config)
         input_var_names = []
         output_vars = []
-        for k in towers:
+        for idx, k in enumerate(towers):
             towername = PREDICT_TOWER + str(k)
             input_vars = config.model.build_placeholders(
                 prefix=towername + '-')
             logger.info(
                 "Building graph for predictor tower {}...".format(k))
             with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
-                    TowerContext(towername, is_training=False):
+                    TowerContext(towername, is_training=False), \
+                    tf.variable_scope(tf.get_variable_scope(),
+                                      reuse=True if idx > 0 else None):
                 config.model.build_graph(input_vars)
-                tf.get_variable_scope().reuse_variables()
             input_var_names.extend([k.name for k in input_vars])
             output_vars.extend(get_tensors_by_names(
                 [towername + '/' + n
...
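After each tower is built, the predictor collects its inputs and outputs by name (`towername + '/' + n`). A sketch of the same name-based lookup using the plain TF 1.x graph API (`get_tensors_by_names` is tensorpack's helper; `towerp0` and `output` are illustrative names):

    import tensorflow as tf

    g = tf.Graph()
    with g.as_default():
        with tf.name_scope('towerp0'):
            tf.add(tf.constant(1.0), tf.constant(2.0), name='output')

    # a tensor named '<tower>/<name>' is addressable as '<tower>/<name>:0'
    out = g.get_tensor_by_name('towerp0/output:0')
    with tf.Session(graph=g) as sess:
        print(sess.run(out))  # 3.0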