Commit b6370d50 authored by Yuxin Wu

fix multigpu naming problem

parent c08297ff
@@ -20,9 +20,26 @@ class PeriodicSaver(PeriodicCallback):
     def _before_train(self):
         self.path = os.path.join(logger.LOG_DIR, 'model')
         self.saver = tf.train.Saver(
+            var_list=self._get_vars(),
             max_to_keep=self.keep_recent,
             keep_checkpoint_every_n_hours=self.keep_freq)
 
+    def _get_vars(self):
+        vars = tf.all_variables()
+        var_dict = {}
+        for v in vars:
+            name = v.op.name
+            if re.match('tower[1-9]', name):
+                logger.info("Skip {} when saving model.".format(name))
+                continue
+            if 'tower0/' in name:
+                new_name = name.replace('tower0/', '')
+                logger.info(
+                    "{} renamed to {} when saving model.".format(name, new_name))
+                name = new_name
+            var_dict[name] = v
+        return var_dict
+
     def _trigger_periodic(self):
         self.saver.save(
             tf.get_default_session(),
...
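The new `_get_vars()` exploits the fact that `tf.train.Saver(var_list=...)` accepts a dict mapping checkpoint names to variables: replica towers are skipped and the `tower0/` prefix is stripped, so a multi-GPU run saves a checkpoint that loads into a single-GPU graph unchanged. A minimal, TF-free sketch of the same renaming logic (the helper name is hypothetical, not from this repo):

```python
import re

def checkpoint_name_map(var_names):
    # Mirrors _get_vars() above: drop replica-tower variables and strip
    # the primary tower's prefix, so the saved names look like those of
    # a single-GPU graph.
    mapping = {}
    for name in var_names:
        if re.match('tower[1-9]', name):
            continue  # tower1..towerN hold copies of tower0's weights
        mapping[name.replace('tower0/', '')] = name
    return mapping

# The checkpoint keys carry no tower prefix:
names = ['tower0/conv0/W', 'tower1/conv0/W', 'global_step']
assert checkpoint_name_map(names) == {
    'conv0/W': 'tower0/conv0/W',
    'global_step': 'global_step',
}
```

Note that `re.match` anchors at the start of the name, so a variable that merely contains "tower1" deeper in its scope is still kept.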
@@ -40,9 +40,9 @@ def BatchNorm(x, use_local_stat=True, decay=0.999, epsilon=1e-5):
         initializer=tf.constant_initializer(1.0))
 
     if len(shape) == 2:
-        batch_mean, batch_var = tf.nn.moments(x, [0], name='moments', keep_dims=False)
+        batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False)
     else:
-        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments', keep_dims=False)
+        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], keep_dims=False)
 
     ema = tf.train.ExponentialMovingAverage(decay=decay)
     ema_apply_op = ema.apply([batch_mean, batch_var])
...
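For reference, `tf.nn.moments` computes the per-batch mean and variance over the given axes ([0] for 2D input, [0, 1, 2] for NHWC conv activations), and `ema.apply` maintains running estimates of both for use at inference time. A numpy sketch of the update that `ExponentialMovingAverage` performs with the `decay=0.999` default above (the function name is illustrative only):

```python
import numpy as np

def moments_with_ema(x, axes, ema_mean, ema_var, decay=0.999):
    # Batch statistics, as tf.nn.moments returns them.
    batch_mean = x.mean(axis=axes)
    batch_var = x.var(axis=axes)
    # ema.apply([batch_mean, batch_var]) amounts to this update, keeping
    # the running estimates used when use_local_stat is False.
    ema_mean = decay * ema_mean + (1 - decay) * batch_mean
    ema_var = decay * ema_var + (1 - decay) * batch_var
    return batch_mean, batch_var, ema_mean, ema_var

# 4D conv activations: reduce over batch and spatial axes, keep channels.
x = np.random.randn(32, 8, 8, 16).astype('float32')
stats = moments_with_ema(x, (0, 1, 2), np.zeros(16), np.ones(16))
```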
@@ -118,6 +118,8 @@ class QueueInputTrainer(Trainer):
                 tf.name_scope('tower{}'.format(i)) as scope:
                     model_inputs = get_model_inputs()
                     cost_var = model.get_cost(model_inputs, is_training=True)
+                    if i == 0:
+                        cost_var_t0 = cost_var
                     grad_list.append(
                         self.config.optimizer.compute_gradients(cost_var))
@@ -129,6 +131,7 @@ class QueueInputTrainer(Trainer):
                     del tf.get_collection(k)[:]
                     tf.get_collection(k).extend(kept_summaries[k])
             grads = QueueInputTrainer._average_grads(grad_list)
+            cost_var = cost_var_t0
         else:
             model_inputs = get_model_inputs()
             cost_var = model.get_cost(model_inputs, is_training=True)
...
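In the multi-tower branch, every tower computes its own cost and gradients; the gradients are averaged, but `cost_var` would otherwise be left pointing at the last tower's cost. Stashing `cost_var_t0` and restoring it after `_average_grads` keeps everything downstream (summaries, logging) tied to tower-0 names, consistent with the saver change above. `_average_grads` itself isn't shown in this diff; a plausible numpy sketch of that standard pattern, not the actual method:

```python
import numpy as np

def average_grads(grad_list):
    # grad_list holds one [(grad, var), ...] list per tower; the k-th
    # entry of every tower refers to the same shared variable.
    averaged = []
    for pairs in zip(*grad_list):
        grads = [g for g, _ in pairs]
        _, var = pairs[0]  # variable is shared across towers, take any copy
        averaged.append((np.mean(grads, axis=0), var))
    return averaged

# Two towers, one parameter 'w': gradients 1.0 and 3.0 average to 2.0.
tower0 = [(np.float32(1.0), 'w')]
tower1 = [(np.float32(3.0), 'w')]
assert average_grads([tower0, tower1])[0][0] == 2.0
```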