Commit d44d32d6 authored by Yuxin Wu

organize name scopes in EMA & trainers (#340)

parent 5b310290
@@ -104,10 +104,13 @@ class Model(GANModelDesc):
with tf.variable_scope('dec'):
recon_pos = self.decoder(hidden_pos)
recon_neg = self.decoder(hidden_neg)
with tf.name_scope('viz'):
summary_image('generated-samples', image_gen)
summary_image('reconstruct-real', recon_pos)
summary_image('reconstruct-fake', recon_neg)
with tf.name_scope('losses'):
L_pos = tf.reduce_mean(tf.abs(recon_pos - image_pos), name='loss_pos')
L_neg = tf.reduce_mean(tf.abs(recon_neg - image_gen), name='loss_neg')
......
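A standalone sketch (not part of the commit) of the grouping pattern this hunk applies: wrapping summary ops and loss tensors in explicit name scopes so they collapse into single 'viz' and 'losses' nodes in the TensorBoard graph. tf.summary.image stands in here for the example's summary_image helper, and the tiny reconstruction network is purely illustrative.

import tensorflow as tf

image_pos = tf.placeholder(tf.float32, [None, 28, 28, 1], name='image_pos')
flat = tf.layers.flatten(image_pos)
recon_pos = tf.reshape(tf.layers.dense(flat, 28 * 28), [-1, 28, 28, 1], name='recon_pos')

with tf.name_scope('viz'):
    # summaries collapse into one 'viz' node in the TensorBoard graph view
    tf.summary.image('reconstruct-real', recon_pos)

with tf.name_scope('losses'):
    # the loss tensor is named 'losses/loss_pos'
    L_pos = tf.reduce_mean(tf.abs(recon_pos - image_pos), name='loss_pos')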
@@ -77,6 +77,7 @@ class GANTrainer(Trainer):
opt = model.get_optimizer()
# by default, run one d_min after one g_min
with tf.name_scope('optimize'):
g_min = opt.minimize(model.g_loss, var_list=model.g_vars, name='g_op')
with tf.control_dependencies([g_min]):
d_min = opt.minimize(model.d_loss, var_list=model.d_vars, name='d_op')
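A minimal, self-contained sketch of the ordering this hunk wraps in a name scope: control_dependencies guarantees the generator step finishes before the discriminator step starts, so one run of the final op performs "one d_min after one g_min". The toy losses below are illustrative, not the GAN model.

import tensorflow as tf

g_var = tf.get_variable('g_var', shape=[], initializer=tf.constant_initializer(1.0))
d_var = tf.get_variable('d_var', shape=[], initializer=tf.constant_initializer(1.0))
g_loss = tf.square(g_var, name='g_loss')
d_loss = tf.square(d_var, name='d_loss')

opt = tf.train.GradientDescentOptimizer(0.1)
with tf.name_scope('optimize'):
    g_min = opt.minimize(g_loss, var_list=[g_var], name='g_op')
    with tf.control_dependencies([g_min]):
        # d_op cannot start before g_op has finished
        d_min = opt.minimize(d_loss, var_list=[d_var], name='d_op')
train_op = d_min  # one sess.run(train_op) does a G step, then a D step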
@@ -106,6 +107,7 @@ class SeparateGANTrainer(Trainer):
model.build_graph(input)
opt = model.get_optimizer()
with tf.name_scope('optimize'):
self.d_min = opt.minimize(
model.d_loss, var_list=model.d_vars, name='d_min')
self.g_min = opt.minimize(
@@ -142,6 +144,7 @@ class MultiGPUGANTrainer(Trainer):
cost_list = MultiGPUTrainerBase.build_on_multi_tower(
config.tower, get_cost, devices)
# simply average the cost. It might get faster to average the gradients
with tf.name_scope('optimize'):
d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu)
g_loss = tf.add_n([x[1] for x in cost_list]) * (1.0 / nr_gpu)
......
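The comment in this hunk notes that averaging the per-tower gradients might be faster than averaging the costs. Below is a hedged sketch of that alternative; average_tower_gradients, tower_losses and all_vars are hypothetical names, not part of the trainer.

import tensorflow as tf

def average_tower_gradients(opt, tower_losses, all_vars):
    # one compute_gradients call per tower, then average each variable's gradient
    tower_grads = [opt.compute_gradients(loss, var_list=all_vars) for loss in tower_losses]
    averaged = []
    for per_tower in zip(*tower_grads):  # per_tower: one (grad, var) pair per GPU
        grads = [g for g, _ in per_tower if g is not None]
        var = per_tower[0][1]
        averaged.append((tf.add_n(grads) * (1.0 / len(grads)), var))
    return opt.apply_gradients(averaged)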
@@ -155,4 +155,4 @@ if __name__ == '__main__':
if config.nr_tower <= 1:
QueueInputTrainer(config).train()
else:
SyncMultiGPUTrainer(config).train()
AsyncMultiGPUTrainer(config).train()
@@ -146,8 +146,11 @@ class ModelDesc(ModelDescBase):
``tf.GraphKeys.REGULARIZATION_LOSSES`` to the cost automatically.
"""
cost = self._get_cost()
return tf.add(cost, regularize_cost_from_collection(),
name='cost_with_regularizer')
reg_cost = regularize_cost_from_collection()
if reg_cost:
return tf.add(cost, reg_cost, name='cost_with_regularizer')
else:
return cost
def _get_cost(self, *args):
return self.cost
......
@@ -64,7 +64,7 @@ def regularize_cost_from_collection(name='regularize_cost'):
In replicated mode, will only regularize variables within the current tower.
Returns:
a scalar tensor, the regularization loss.
a scalar tensor, the regularization loss, or 0
"""
regularization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
ctx = get_current_tower_context()
@@ -73,12 +73,11 @@ def regularize_cost_from_collection(name='regularize_cost'):
# It is only added with variables that are newly created.
if ctx.has_own_variables: # be careful of the first tower (name='')
regularization_losses = ctx.filter_vars_by_vs_name(regularization_losses)
print([k.name for k in regularization_losses])
logger.info("Add REGULARIZATION_LOSSES of {} tensors on the total cost.".format(len(regularization_losses)))
reg_loss = tf.add_n(list(regularization_losses), name=name)
return reg_loss
else:
return tf.constant(0, dtype=tf.float32, name='empty_' + name)
return 0
@layer_register(log_shape=False, use_scope=False)
......
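For context, a standalone sketch (plain TF 1.x, not the tensorpack code) of how REGULARIZATION_LOSSES is typically populated and summed. Returning a plain Python 0 when the collection is empty lets get_cost skip the extra add op entirely instead of adding a useless constant to the graph.

import tensorflow as tf

w = tf.get_variable('w', shape=[10, 10],
                    regularizer=tf.contrib.layers.l2_regularizer(1e-4))
# the regularizer registers a tensor in tf.GraphKeys.REGULARIZATION_LOSSES

losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
if losses:
    reg_cost = tf.add_n(losses, name='regularize_cost')
else:
    reg_cost = 0  # plain 0, so the caller can return the unmodified cost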
@@ -174,6 +174,7 @@ class AccumGradOptimizer(ProxyOptimizer):
counter = tf.Variable(
0, name="counter", trainable=False, dtype=tf.int32)
with tf.name_scope('AccumGradOptimizer'):
ops = []
for s, gv in zip(slots, grads_and_vars):
g, v = gv
@@ -201,7 +202,7 @@ if __name__ == '__main__':
x = tf.get_variable('x', shape=[6])
cost = tf.reduce_sum(tf.abs(x), name='cost')
opt = tf.train.GradientDescentOptimizer(0.01)
# opt = AccumGradOptimizer(opt, 5)
opt = AccumGradOptimizer(opt, 5)
min_op = opt.minimize(cost)
sess = tf.Session()
......
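The rest of this __main__ block is elided above. As a separate, hedged illustration (assuming AccumGradOptimizer lives in tensorpack.tfutils.optimizer), a run loop like the one below shows the expected behavior: x only changes on every 5th step, when the accumulated gradient is applied.

import tensorflow as tf
from tensorpack.tfutils.optimizer import AccumGradOptimizer  # assumed import path

x = tf.get_variable('x', shape=[6])
cost = tf.reduce_sum(tf.abs(x), name='cost')
opt = AccumGradOptimizer(tf.train.GradientDescentOptimizer(0.01), 5)
min_op = opt.minimize(cost)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(10):
        sess.run(min_op)
        print(step, sess.run(x))  # x should stay fixed except after steps 4 and 9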
@@ -7,11 +7,13 @@ import tensorflow as tf
import re
import io
from six.moves import range
from contextlib import contextmanager
from tensorflow.python.training import moving_averages
from ..utils import logger
from ..utils.develop import log_deprecated
from ..utils.argtools import graph_memoized
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .tower import get_current_tower_context
from .symbolic_functions import rms
@@ -140,6 +142,20 @@ def add_param_summary(*summary_lists):
perform(p, act)
@graph_memoized
def _get_cached_vs(name):
with tf.variable_scope(name) as scope:
return scope
@contextmanager
def _enter_vs_reuse_ns(name):
vs = _get_cached_vs(name)
with tf.variable_scope(vs):
with tf.name_scope(vs.original_name_scope):
yield vs
def add_moving_summary(v, *args, **kwargs):
"""
Enable moving average summary for some tensors.
@@ -173,18 +189,17 @@ def add_moving_summary(v, *args, **kwargs):
for c in v:
name = re.sub('tower[0-9]+/', '', c.op.name)
with G.colocate_with(c):
with tf.variable_scope('EMA') as vs:
with G.colocate_with(c), tf.name_scope(None):
with _enter_vs_reuse_ns('EMA') as vs:
# will actually create ns EMA_1, EMA_2, etc. tensorflow#6007
ema_var = tf.get_variable(name, shape=c.shape, dtype=c.dtype,
initializer=tf.constant_initializer(), trainable=False)
ns = vs.original_name_scope
# first clear NS to avoid duplicated name in variables
with tf.name_scope(None), tf.name_scope(ns):
with tf.name_scope(ns):
ema_op = moving_averages.assign_moving_average(
ema_var, c, decay,
zero_debias=True, name=name + '_EMA_apply')
with tf.name_scope(None):
tf.summary.scalar(name + '-summary', ema_op)
tf.add_to_collection(coll, ema_op)
# TODO a new collection to summary every step?
......
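A self-contained sketch of the scope-reuse trick the EMA hunk relies on: cache the VariableScope object the first time it is opened, then re-enter both it and its original name scope, so repeated calls keep using 'EMA/' instead of spawning 'EMA_1/', 'EMA_2/' (tensorflow#6007). The dict-based cache below is a simplification of graph_memoized.

import tensorflow as tf
from contextlib import contextmanager

_CACHED_VS = {}

def get_cached_vs(name):
    if name not in _CACHED_VS:
        with tf.variable_scope(name) as scope:
            _CACHED_VS[name] = scope
    return _CACHED_VS[name]

@contextmanager
def enter_vs_reuse_ns(name):
    vs = get_cached_vs(name)
    with tf.variable_scope(vs):
        # re-open the original name scope instead of letting TF create a uniquified one
        with tf.name_scope(vs.original_name_scope):
            yield vs

with enter_vs_reuse_ns('EMA'):
    a = tf.get_variable('a', shape=[])
with enter_vs_reuse_ns('EMA'):
    b = tf.get_variable('b', shape=[])
print(a.op.name, b.op.name)  # both live under 'EMA/', no 'EMA_1/'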
@@ -403,6 +403,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase):
train_ops = []
opt = model.get_optimizer()
with tf.name_scope('async_apply_gradients'):
for i, grad_and_vars in enumerate(zip(*grad_list)):
# Ngpu x 2
v = grad_and_vars[0][1]
......
@@ -11,7 +11,7 @@ if six.PY2:
else:
import functools
__all__ = ['map_arg', 'memoized', 'shape2d', 'shape4d',
__all__ = ['map_arg', 'memoized', 'graph_memoized', 'shape2d', 'shape4d',
'memoized_ignoreargs', 'log_once']
......
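graph_memoized, now exported above, memoizes a function's result per tf.Graph. Below is a simplified, hedged sketch of that behavior; the real decorator in tensorpack.utils.argtools may differ in details such as cache cleanup.

import functools
import tensorflow as tf

def graph_memoized_sketch(func):
    cache = {}  # keyed by (default graph, args)
    @functools.wraps(func)
    def wrapper(*args):
        key = (tf.get_default_graph(), args)
        if key not in cache:
            cache[key] = func(*args)
        return cache[key]
    return wrapper

@graph_memoized_sketch
def make_zero():
    return tf.constant(0.0, name='zero')

with tf.Graph().as_default():
    c1 = make_zero()
    assert make_zero() is c1       # cached within the same graph
with tf.Graph().as_default():
    assert make_zero() is not c1   # a fresh graph gets a fresh result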