Commit a62e68a3 authored by Yuxin Wu

Remove prefix for first tower in replicated mode. Support inference now.

parent 7290319f
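In replicated mode the first tower no longer opens a variable scope of its own, so the master copy of every variable keeps its plain, unprefixed name; only the other towers get a tower prefix. That is what lets Saver and PredictTower (checkpointing and inference) find the variables directly. A rough illustration of the intended naming, with made-up variable names:

# before: tower0/conv0/W, tower1/conv0/W, ...   (master copy also prefixed)
# after:  conv0/W          <- master copy, what Saver / PredictTower look up
#         tower1/conv0/W   <- replica, synced from the master copy after init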
@@ -46,13 +46,13 @@ def regularize_cost(regex, func, name='regularize_cost'):
     for p in params:
         para_name = p.name
         # in replicated mode, only regularize variables inside this tower
-        if ctx.has_own_variables and (not para_name.startswith(ctx.name)):
+        if ctx.has_own_variables and (not para_name.startswith(ctx.vs_name)):
             continue
         if re.search(regex, para_name):
             costs.append(func(p))
             _log_regularizer(para_name)
     if not costs:
-        return 0
+        return tf.constant(0, dtype=tf.float32, name='empty_regularize_cost')
     return tf.add_n(costs, name=name)
......
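A minimal sketch of the resulting behavior of regularize_cost, with a hypothetical vs_name argument standing in for the tower context (not the actual tensorpack signature): variables outside the current tower's variable scope are skipped, and when nothing matches a named float32 constant is returned instead of the Python int 0, so callers always get back a tensor.

import re
import tensorflow as tf

def regularize_cost_sketch(regex, func, params, vs_name, name='regularize_cost'):
    """Illustrative re-implementation of the logic above, not the library API."""
    costs = []
    for p in params:
        para_name = p.name
        # in replicated mode, only regularize variables inside this tower;
        # an empty vs_name (first tower) matches everything
        if vs_name and not para_name.startswith(vs_name):
            continue
        if re.search(regex, para_name):
            costs.append(func(p))
    if not costs:
        # a real tensor, so downstream code can add / fetch it uniformly
        return tf.constant(0, dtype=tf.float32, name='empty_regularize_cost')
    return tf.add_n(costs, name=name)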
@@ -59,6 +59,16 @@ class TowerContext(object):
     def name(self):
         return self._name
 
+    # variable_scope name
+    @property
+    def vs_name(self):
+        if self.has_own_variables:
+            # do not open new variable scope for the main tower,
+            # just use '', so that Saver & PredictTower know what to do
+            if self.index > 0:
+                return self._name
+        return ""
+
     @property
     def index(self):
         if self._name == '':
@@ -103,8 +113,8 @@ class TowerContext(object):
         self._ctxs = []
         if len(self._name):
             if self.has_own_variables:
-                # open new variable scopes
-                self._ctxs.append(tf.variable_scope(self._name))
+                if self.vs_name:
+                    self._ctxs.append(tf.variable_scope(self.vs_name))
             else:
                 # use existing variable scope
                 reuse = self.index > 0 or (not self.is_training)
......
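The vs_name contract can be summarized with a small stand-in class (simplified; the real TowerContext derives index from the tower name and also handles training vs. inference): the first replicated tower reports an empty variable-scope name, every other tower reports its own name, so only the non-master towers open a new scope.

class TowerContextSketch(object):
    """Simplified model of the vs_name behavior, not the real class."""
    def __init__(self, name, index, has_own_variables):
        self._name = name
        self.index = index
        self.has_own_variables = has_own_variables

    @property
    def vs_name(self):
        # replicated towers other than the first get their own variable scope;
        # the first tower uses '' so its variables keep unprefixed names
        if self.has_own_variables and self.index > 0:
            return self._name
        return ""

# TowerContextSketch('tower0', 0, True).vs_name == ''
# TowerContextSketch('tower1', 1, True).vs_name == 'tower1'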
@@ -68,9 +68,9 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         # GATE_NONE faster?
         varlist = tf.trainable_variables()
         ctx = get_current_tower_context()
-        if ctx is not None and ctx.has_own_variables:
+        if ctx is not None and ctx.has_own_variables and ctx.vs_name:
             # only optimize w.r.t vars in this tower
-            varlist = [v for v in varlist if v.op.name.startswith(ctx.name + '/')]
+            varlist = [v for v in varlist if v.op.name.startswith(ctx.vs_name + '/')]
         grads = opt.compute_gradients(
             cost,
             var_list=varlist,
......
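With an empty vs_name the first tower therefore keeps the full variable list, while every other tower only optimizes the variables living under its own prefix. A short sketch of that filter (the helper name and variables are hypothetical):

def filter_tower_vars(varlist, vs_name):
    # empty vs_name (first tower, or non-replicated mode): optimize everything
    if not vs_name:
        return varlist
    # otherwise keep only variables created under this tower's scope
    return [v for v in varlist if v.op.name.startswith(vs_name + '/')]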
@@ -244,7 +244,6 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         super(SyncMultiGPUTrainerReplicated, self)._setup()
         raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
 
-        opt = self.model.get_optimizer()  # XXX call before build tower to avoid opt under tower scopes.
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower,
             lambda: self._get_cost_and_grad()[1],
@@ -252,6 +251,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         grads = self._allreduce_grads(grad_list)
 
         train_ops = []
+        opt = self.model.get_optimizer()
         for idx in range(self.config.nr_tower):
             with tf.device(raw_devices[idx]):
                 grad_and_vars = [x[idx] for x in grads]
@@ -272,9 +272,10 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         post_init_ops = []
         for v in global_vars:
             split_name = v.name.split('/')
-            if split_name[0] == 'tower0' or not v.name.startswith('tower'):
+            if not v.name.startswith('tower'):
                 continue
-            split_name[0] = 'tower0'
+            # the master name doesn't have the towerx/ prefix
+            split_name = split_name[1:]
             copy_from = var_by_name['/'.join(split_name)]
             post_init_ops.append(v.assign(copy_from.read_value()))
         return tf.group(*post_init_ops, name='init_sync_vars')
......
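The post-init sync now maps every variable named towerK/... back to the unprefixed master name simply by dropping the first path component, then assigns the master's value to the replica. A hedged sketch of the mapping (the helper and names are illustrative):

def master_name(var_name):
    # 'tower1/conv0/W:0' -> 'conv0/W:0'; unprefixed names are already the master copy
    if not var_name.startswith('tower'):
        return var_name
    return '/'.join(var_name.split('/')[1:])

# replicas are then initialized from the master value, e.g.
# post_init_ops.append(v.assign(var_by_name[master_name(v.name)].read_value()))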