Commit b6370d50 authored by Yuxin Wu

fix multigpu naming problem

parent c08297ff
......@@ -20,9 +20,26 @@ class PeriodicSaver(PeriodicCallback):
def _before_train(self):
    """Prepare the checkpoint path and the saver before training begins.

    Sets ``self.path`` to ``<logger.LOG_DIR>/model`` and ``self.saver`` to a
    ``tf.train.Saver`` restricted to the mapping returned by
    ``self._get_vars()``, configured from ``self.keep_recent`` and
    ``self.keep_freq``.
    """
    self.path = os.path.join(logger.LOG_DIR, 'model')
    # Collect the saver options first so the constructor call stays short.
    saver_options = {
        'var_list': self._get_vars(),
        'max_to_keep': self.keep_recent,
        'keep_checkpoint_every_n_hours': self.keep_freq,
    }
    self.saver = tf.train.Saver(**saver_options)
def _get_vars(self):
    """Return the ``{saved name: variable}`` mapping used to build the saver.

    Every variable from ``tf.all_variables()`` is considered:

    * a variable whose op name begins with ``tower`` followed by a digit
      1-9 is skipped (and the skip is logged);
    * every occurrence of ``tower0/`` in the op name is removed, and the
      rename is logged;
    * any other variable is stored under its op name unchanged.

    If two variables end up with the same saved name, the one visited
    later replaces the earlier entry.
    """
    saved = {}
    for variable in tf.all_variables():
        op_name = variable.op.name
        # re.match anchors at the start of the name only.
        if re.match('tower[1-9]', op_name):
            logger.info("Skip {} when saving model.".format(op_name))
            continue
        save_name = op_name
        if 'tower0/' in op_name:
            save_name = op_name.replace('tower0/', '')
            logger.info(
                "{} renamed to {} when saving model.".format(op_name, save_name))
        saved[save_name] = variable
    return saved
def _trigger_periodic(self):
self.saver.save(
tf.get_default_session(),
......
......@@ -40,9 +40,9 @@ def BatchNorm(x, use_local_stat=True, decay=0.999, epsilon=1e-5):
initializer=tf.constant_initializer(1.0))
if len(shape) == 2:
batch_mean, batch_var = tf.nn.moments(x, [0], name='moments', keep_dims=False)
batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False)
else:
batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments', keep_dims=False)
batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], keep_dims=False)
ema = tf.train.ExponentialMovingAverage(decay=decay)
ema_apply_op = ema.apply([batch_mean, batch_var])
......
......@@ -118,6 +118,8 @@ class QueueInputTrainer(Trainer):
tf.name_scope('tower{}'.format(i)) as scope:
model_inputs = get_model_inputs()
cost_var = model.get_cost(model_inputs, is_training=True)
if i == 0:
cost_var_t0 = cost_var
grad_list.append(
self.config.optimizer.compute_gradients(cost_var))
......@@ -129,6 +131,7 @@ class QueueInputTrainer(Trainer):
del tf.get_collection(k)[:]
tf.get_collection(k).extend(kept_summaries[k])
grads = QueueInputTrainer._average_grads(grad_list)
cost_var = cost_var_t0
else:
model_inputs = get_model_inputs()
cost_var = model.get_cost(model_inputs, is_training=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment