Commit dadd971c authored by Yuxin Wu

Clean up TowerContext and pass the tower index into it (a better solution to #310).

parent c2edd999
......@@ -176,6 +176,7 @@ class PredictorTowerBuilder(object):
tower (int): the tower will be built on device '/gpu:{tower}', or
'/cpu:0' if tower is -1.
"""
toweridx = max(tower, 0) # if CPU, named the tower as 0
towername = TowerContext.get_predict_tower_name(tower, self._prefix)
if self._prefix:
msg = "Building predictor graph {} on gpu={} with prefix='{}' ...".format(
......@@ -187,7 +188,8 @@ class PredictorTowerBuilder(object):
device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'
with tf.name_scope(None), \
freeze_collection(TOWER_FREEZE_KEYS), \
TowerContext(towername, device=device, is_training=False):
tf.device(device), \
TowerContext(towername, is_training=False, index=toweridx):
self._fn(tower)
# useful only when the placeholders don't have tower prefix
......
......@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
import re
from ..utils.naming import PREDICT_TOWER
__all__ = ['get_current_tower_context', 'TowerContext']
......@@ -16,24 +15,27 @@ class TowerContext(object):
""" A context where the current model is being built in. """
def __init__(self, tower_name,
device=None, is_training=None,
is_training=None,
index=0,
var_strategy='shared',
vs_name=None):
"""
Args:
tower_name (str): 'tower0', 'towerp0', or ''
device (str or device function): the device to use. Defaults to either cpu0 or gpu0.
tower_name (str): The name scope of the tower. Currently used
values are like: 'tower0', 'towerp0', or ''
is_training (bool): if None, automatically determine from tower_name.
index (int): index of this tower
var_strategy (str): either 'shared' or 'replicated'.
vs_name (str): the variable scope name to open. Only valid in
'replicated' mode. Defaults to be tower_name.
"""
self._name = tower_name
self._device = device
if is_training is None:
is_training = not self._name.startswith(PREDICT_TOWER)
self._is_training = is_training
self._is_training = bool(is_training)
self._index = index
assert var_strategy in ['replicated', 'shared'], var_strategy
self._var_strategy = var_strategy
......@@ -49,11 +51,11 @@ class TowerContext(object):
@property
def is_main_training_tower(self):
return self.is_training and (self._name == '' or self._name == 'tower0')
return self.is_training and self._index == 0
@property
def is_main_tower(self):
return self._name == '' or self._name == 'tower0'
return self._index == 0
@property
def is_training(self):
......@@ -67,37 +69,17 @@ class TowerContext(object):
def name(self):
return self._name
# TODO remove this and add something like `tower.variables`
# variable_scope name
@property
def vs_name(self):
return self._vs_name
# TODO pass index into the constructor
@property
def index(self):
if self._name == '':
return 0
idx = re.findall('[0-9]+$', self._name)
if len(idx) == 0:
return 0
return int(idx[0])
@property
def device(self):
return self._device
def find_tensor_in_main_tower(self, graph, name):
if self.is_main_tower:
return graph.get_tensor_by_name(name)
if name.startswith(PREDICT_TOWER):
predict_tower_prefix = '{}[0-9]+/'.format(PREDICT_TOWER)
newname = re.sub(predict_tower_prefix, '', name)
try:
return graph.get_tensor_by_name(newname)
except KeyError:
newname = re.sub(predict_tower_prefix, 'tower0/', name)
return graph.get_tensor_by_name(newname)
return self._index
# TODO something similar for training
@staticmethod
def get_predict_tower_name(towerid=0, prefix=''):
"""
......@@ -124,15 +106,14 @@ class TowerContext(object):
self._ctxs.append(tf.variable_scope(self.vs_name))
else:
if self.is_training:
reuse = self.index > 0
reuse = self._index > 0
if reuse is True:
# clear old name_scope and re-enter the current variable_scope
self._ctxs.append(tf.name_scope(None))
self._ctxs.append(tf.variable_scope(
tf.get_variable_scope(), reuse=True))
# if not training, should handle vs outside (TODO not good)
self._ctxs.append(tf.name_scope(self._name))
if self._device is not None:
self._ctxs.append(tf.device(self._device))
for c in self._ctxs:
c.__enter__()
......
......@@ -81,9 +81,10 @@ class MultiGPUTrainerBase(Trainer):
for idx, t in enumerate(towers):
device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
with TowerContext(
with tf.device(device), TowerContext(
'tower{}'.format(idx),
device=device, is_training=True,
is_training=True,
index=idx,
var_strategy=var_strategy,
vs_name=vs_names[idx]):
if idx == t:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment