Commit a62e68a3 authored by Yuxin Wu

Remove prefix for first tower in replicated mode. Support inference now.

parent 7290319f
@@ -46,13 +46,13 @@ def regularize_cost(regex, func, name='regularize_cost'):
     for p in params:
         para_name = p.name
         # in replicated mode, only regularize variables inside this tower
-        if ctx.has_own_variables and (not para_name.startswith(ctx.name)):
+        if ctx.has_own_variables and (not para_name.startswith(ctx.vs_name)):
            continue
         if re.search(regex, para_name):
             costs.append(func(p))
             _log_regularizer(para_name)
     if not costs:
-        return 0
+        return tf.constant(0, dtype=tf.float32, name='empty_regularize_cost')
     return tf.add_n(costs, name=name)
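Returning a named float32 tensor instead of the Python int 0 means regularize_cost() always yields a tf.Tensor, even when no variable matches the regex. A minimal sketch of why that matters, using hypothetical costs and the TF1-style API of this codebase:

    import tensorflow as tf

    other_cost = tf.constant(1.5, name='model_cost')

    # With the old `return 0`, the regularization term could be a plain int;
    # now it is always a float32 tensor with a stable graph name:
    reg_cost = tf.constant(0, dtype=tf.float32, name='empty_regularize_cost')
    total_cost = tf.add(reg_cost, other_cost, name='total_cost')

    with tf.Session() as sess:
        print(sess.run(total_cost))  # 1.5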
@@ -59,6 +59,16 @@ class TowerContext(object):
     def name(self):
         return self._name
 
+    # variable_scope name
+    @property
+    def vs_name(self):
+        if self.has_own_variables:
+            # do not open new variable scope for the main tower,
+            # just use '', so that Saver & PredictTower know what to do
+            if self.index > 0:
+                return self._name
+        return ""
+
     @property
     def index(self):
         if self._name == '':
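With vs_name, the first tower's variables live in the root variable scope while replica towers keep their towerN/ prefix, so a plain Saver and PredictTower see the usual un-prefixed names. A hypothetical sketch of the mapping (illustration only, not tensorpack code):

    def vs_name_for(name, index, has_own_variables=True):
        # Mimics TowerContext.vs_name for illustration.
        if has_own_variables and index > 0:
            return name
        return ""

    assert vs_name_for('tower0', 0) == ''        # main tower: no prefix
    assert vs_name_for('tower1', 1) == 'tower1'  # replicas keep their prefix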
@@ -103,8 +113,8 @@ class TowerContext(object):
         self._ctxs = []
         if len(self._name):
             if self.has_own_variables:
-                # open new variable scopes
-                self._ctxs.append(tf.variable_scope(self._name))
+                if self.vs_name:
+                    self._ctxs.append(tf.variable_scope(self.vs_name))
             else:
                 # use existing variable scope
                 reuse = self.index > 0 or (not self.is_training)
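A runnable sketch of the resulting conditional scope opening, under the same naming assumption (illustration only):

    import tensorflow as tf

    def make_tower_var(vs_name):
        # Open a variable scope only when vs_name is non-empty,
        # mirroring the conditional added above.
        if vs_name:
            with tf.variable_scope(vs_name):
                return tf.get_variable('W', shape=[3])
        return tf.get_variable('W', shape=[3])

    w_main = make_tower_var('')       # name: 'W:0'
    w_rep = make_tower_var('tower1')  # name: 'tower1/W:0'
    print(w_main.name, w_rep.name)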
@@ -68,9 +68,9 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         # GATE_NONE faster?
         varlist = tf.trainable_variables()
         ctx = get_current_tower_context()
-        if ctx is not None and ctx.has_own_variables:
+        if ctx is not None and ctx.has_own_variables and ctx.vs_name:
             # only optimize w.r.t vars in this tower
-            varlist = [v for v in varlist if v.op.name.startswith(ctx.name + '/')]
+            varlist = [v for v in varlist if v.op.name.startswith(ctx.vs_name + '/')]
         grads = opt.compute_gradients(
             cost,
             var_list=varlist,
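This filter makes each replica tower compute gradients only for its own variable copies; for the main tower vs_name is empty, so the full variable list is kept. A self-contained sketch of the prefix filtering, with hypothetical variable names:

    import tensorflow as tf

    with tf.variable_scope('conv0'):
        tf.get_variable('W', shape=[3])          # master copy: 'conv0/W'
    with tf.variable_scope('tower1'):
        with tf.variable_scope('conv0'):
            tf.get_variable('W', shape=[3])      # replica: 'tower1/conv0/W'

    vs_name = 'tower1'
    tower_vars = [v for v in tf.trainable_variables()
                  if v.op.name.startswith(vs_name + '/')]
    print([v.op.name for v in tower_vars])  # ['tower1/conv0/W']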
@@ -244,7 +244,6 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         super(SyncMultiGPUTrainerReplicated, self)._setup()
         raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
 
-        opt = self.model.get_optimizer()  # XXX call before build tower to avoid opt under tower scopes.
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower,
             lambda: self._get_cost_and_grad()[1],
@@ -252,6 +251,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         grads = self._allreduce_grads(grad_list)
 
         train_ops = []
+        opt = self.model.get_optimizer()
         for idx in range(self.config.nr_tower):
             with tf.device(raw_devices[idx]):
                 grad_and_vars = [x[idx] for x in grads]
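Presumably the optimizer can now be created after the towers are built because the first tower no longer opens its own variable scope, so the optimizer's variables land in the root scope either way and the old "call before build tower" workaround is unnecessary.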
@@ -272,9 +272,10 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         post_init_ops = []
         for v in global_vars:
             split_name = v.name.split('/')
-            if split_name[0] == 'tower0' or not v.name.startswith('tower'):
+            if not v.name.startswith('tower'):
                 continue
-            split_name[0] = 'tower0'
+            # the master name doesn't have the towerx/ prefix
+            split_name = split_name[1:]
             copy_from = var_by_name['/'.join(split_name)]
             post_init_ops.append(v.assign(copy_from.read_value()))
         return tf.group(*post_init_ops, name='init_sync_vars')
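The sync op now maps each replica variable to the un-prefixed master copy rather than to an explicit tower0/ name. A tiny sketch of the name mapping, with hypothetical variable names:

    # replica variable -> master variable, by dropping the 'towerX/' prefix
    replica_name = 'tower1/conv0/W'
    split_name = replica_name.split('/')
    master_name = '/'.join(split_name[1:])
    assert master_name == 'conv0/W'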