Commit 81d5fbd8 authored by Yuxin Wu

[SuperResolution] closer to paper's settings (#541)

parent 6b10019e
@@ -14,29 +14,27 @@ from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils import logger
from data_sampler import ImageDecode
from GAN import MultiGPUGANTrainer, GANModelDesc
from GAN import SeparateGANTrainer, GANModelDesc
Reduction = tf.losses.Reduction
BATCH_SIZE = 6
BATCH_SIZE = 16
CHANNELS = 3
SHAPE_LR = 32
NF = 64
VGG_MEAN = np.array([123.68, 116.779, 103.939]) # RGB
GAN_FACTOR_PARAMETER = 2.
def normalize(v):
assert isinstance(v, tf.Tensor)
v.get_shape().assert_has_rank(4)
dim = v.get_shape().as_list()
return v / (dim[1] * dim[2] * dim[3])
return v / tf.reduce_mean(v, axis=[1, 2, 3], keep_dims=True)
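For intuition (illustrative only, assuming TF 1.x semantics; not part of this commit): the new normalize rescales each sample by the mean over its own H, W, C entries, so every activation map ends up with unit mean, instead of being divided by the constant H * W * C.

import numpy as np
import tensorflow as tf

v = tf.placeholder(tf.float32, [None, 8, 8, 4])
# keep_dims=True yields shape [B, 1, 1, 1], so the division broadcasts over H, W, C
normalized = v / tf.reduce_mean(v, axis=[1, 2, 3], keep_dims=True)

with tf.Session() as sess:
    x = np.random.rand(2, 8, 8, 4).astype(np.float32)
    out = sess.run(normalized, feed_dict={v: x})
    print(out.reshape(2, -1).mean(axis=1))  # ~[1. 1.]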
def gram_matrix(v):
assert isinstance(v, tf.Tensor)
v.get_shape().assert_has_rank(4)
dim = v.get_shape().as_list()
v = normalize(v)
v = tf.reshape(v, [-1, dim[1] * dim[2], dim[3]])
return tf.matmul(v, v, transpose_a=True)
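A quick shape walk-through of gram_matrix (illustrative only): a [B, H, W, C] activation is normalized, flattened to [B, H*W, C], and matmul with transpose_a=True yields a [B, C, C] matrix of channel inner products.

import tensorflow as tf

v = tf.zeros([6, 16, 16, 64])                   # [B, H, W, C]
flat = tf.reshape(v, [-1, 16 * 16, 64])         # [B, H*W, C]
gram = tf.matmul(flat, flat, transpose_a=True)  # [B, C, C]
print(gram.get_shape().as_list())               # [6, 64, 64]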
@@ -49,14 +47,17 @@ class Model(GANModelDesc):
self.width = width
def _get_inputs(self):
# raw images in [0, 255]
return [InputDesc(tf.float32, (None, self.height * 1, self.width * 1, CHANNELS), 'Ilr'),
InputDesc(tf.float32, (None, self.height * 4, self.width * 4, CHANNELS), 'Ihr')]
def _build_graph(self, inputs):
ctx = get_current_tower_context()
Ilr, Ihr = inputs[0] / 255.0, inputs[1] / 255.0
Ibicubic = tf.image.resize_bicubic(Ilr, [4 * self.height, 4 * self.width])
Ibicubic = tf.image.resize_bicubic(
Ilr, [4 * self.height, 4 * self.width], align_corners=True,
name='bicubic_baseline') # (0,1)
VGG_MEAN_TENSOR = tf.constant(VGG_MEAN, dtype=tf.float32)
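Illustrative shape check for the bicubic baseline (not part of the commit; assumes TF 1.x): the 32x32 low-res input is upsampled straight to the 128x128 output grid.

import tensorflow as tf

Ilr = tf.zeros([16, 32, 32, 3])  # BATCH_SIZE low-res images in [0, 1]
up = tf.image.resize_bicubic(Ilr, [4 * 32, 4 * 32], align_corners=True)
print(up.get_shape().as_list())  # [16, 128, 128, 3]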
def resnet_block(x, name):
with tf.variable_scope(name):
@@ -66,10 +67,11 @@ class Model(GANModelDesc):
def upsample(x, factor=2):
_, h, w, _ = x.get_shape().as_list()
x = tf.image.resize_nearest_neighbor(x, [factor * h, factor * w])
x = tf.image.resize_nearest_neighbor(x, [factor * h, factor * w], align_corners=True)
return x
def generator(x, Ibicubic):
x = x - VGG_MEAN_TENSOR / 255.0
with argscope(Conv2D, kernel_shape=3, stride=1, nl=tf.nn.relu):
x = Conv2D('conv1', x, NF)
for i in range(10):
@@ -81,10 +83,11 @@ class Model(GANModelDesc):
x = Conv2D('conv_post_3', x, NF)
Ires = Conv2D('conv_post_4', x, 3, nl=tf.identity)
Iest = tf.add(Ibicubic, Ires, name='Iest')
return Iest
return Iest # [0,1]
@auto_reuse_variable_scope
def discriminator(x):
x = x - VGG_MEAN_TENSOR / 255.0
with argscope(Conv2D, kernel_shape=3, stride=1, nl=tf.nn.leaky_relu):
x = Conv2D('conv0', x, 32)
x = Conv2D('conv0b', x, 32, stride=2)
@@ -104,7 +107,8 @@ class Model(GANModelDesc):
def additional_losses(a, b):
with tf.variable_scope('VGG19'):
x = tf.concat([a, b], axis=0)
x = tf.reshape(x, [2 * BATCH_SIZE, 128, 128, 3])
x = tf.reshape(x, [2 * BATCH_SIZE, SHAPE_LR * 4, SHAPE_LR * 4, 3]) * 255.0
x = x - VGG_MEAN_TENSOR
# VGG 19
with varreplace.freeze_variables():
with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu):
@@ -132,6 +136,8 @@ class Model(GANModelDesc):
# perceptual loss
with tf.name_scope('perceptual_loss'):
pool2 = normalize(pool2)
pool5 = normalize(pool5)
phi_a_1, phi_b_1 = tf.split(pool2, 2, axis=0)
phi_a_2, phi_b_2 = tf.split(pool5, 2, axis=0)
@@ -143,23 +149,23 @@ class Model(GANModelDesc):
# texture loss
with tf.name_scope('texture_loss'):
def texture_loss(x, p=16):
x = normalize(x)
_, h, w, c = x.get_shape().as_list()
x = normalize(x)
assert h % p == 0 and w % p == 0
logger.info('Create texture loss for layer {} with shape {}'.format(x.name, x.get_shape()))
x = tf.space_to_batch_nd(x, [p, p], [[0, 0], [0, 0]])
x = tf.reshape(x, [p, p, -1, h // p, w // p, c])
x = tf.transpose(x, [2, 3, 4, 0, 1, 5])
patches_a, patches_b = tf.split(x, 2) # each is b,h/p,w/p,p,p,c
x = tf.space_to_batch_nd(x, [p, p], [[0, 0], [0, 0]]) # [b * p * p, h/p, w/p, c]
x = tf.reshape(x, [p, p, -1, h // p, w // p, c]) # [p, p, b, h/p, w/p, c]
x = tf.transpose(x, [2, 3, 4, 0, 1, 5]) # [b, h/p, w/p, p, p, c]
patches_a, patches_b = tf.split(x, 2, axis=0) # each is b,h/p,w/p,p,p,c
patches_a = tf.reshape(patches_a, [-1, p, p, c])
patches_b = tf.reshape(patches_b, [-1, p, p, c])
patches_a = tf.reshape(patches_a, [-1, p, p, c]) # [b * ?, p, p, c]
patches_b = tf.reshape(patches_b, [-1, p, p, c]) # [b * ?, p, p, c]
return tf.losses.mean_squared_error(
gram_matrix(patches_a),
gram_matrix(patches_b),
reduction=Reduction.SUM
) * (1.0 / BATCH_SIZE)
reduction=Reduction.MEAN
)
texture_loss_conv1_1 = tf.identity(texture_loss(conv1_1), name='normalized_conv1_1')
texture_loss_conv2_1 = tf.identity(texture_loss(conv2_1), name='normalized_conv2_1')
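To follow the space_to_batch_nd / reshape / transpose sequence above, an illustrative shape trace (assumes h = w = 128, p = 16, c = 64, and the 2 * BATCH_SIZE fake/real stack built in additional_losses):

import tensorflow as tf

B, h, w, c, p = 2 * 16, 128, 128, 64, 16  # 2 * BATCH_SIZE stacked activations
x = tf.zeros([B, h, w, c])
x = tf.space_to_batch_nd(x, [p, p], [[0, 0], [0, 0]])  # [B*p*p, h/p, w/p, c]
x = tf.reshape(x, [p, p, -1, h // p, w // p, c])       # [p, p, B, h/p, w/p, c]
x = tf.transpose(x, [2, 3, 4, 0, 1, 5])                # [B, h/p, w/p, p, p, c]
a, b = tf.split(x, 2, axis=0)                          # each [B/2, h/p, w/p, p, p, c]
a = tf.reshape(a, [-1, p, p, c])                       # one 16x16 patch per row
print(a.get_shape().as_list())                         # [1024, 16, 16, 64]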
@@ -171,8 +177,7 @@ class Model(GANModelDesc):
fake_hr = generator(Ilr, Ibicubic)
real_hr = Ihr
VGG_MEAN_TENSOR = tf.constant(VGG_MEAN, dtype=tf.float32)
tf.add(fake_hr, VGG_MEAN_TENSOR / 255.0, name='prediction')
tf.multiply(fake_hr, 255.0, name='prediction')
if ctx.is_training:
with tf.variable_scope('discrim'):
@@ -185,18 +190,19 @@ class Model(GANModelDesc):
with tf.name_scope('additional_losses'):
# see table 2 from appendix
loss = []
loss.append(tf.multiply(1., self.g_loss, name="loss_LA"))
loss.append(tf.multiply(GAN_FACTOR_PARAMETER, self.g_loss, name="loss_LA"))
loss.append(tf.multiply(2e-1, additional_losses[0], name="loss_LP1"))
loss.append(tf.multiply(2e-2, additional_losses[1], name="loss_LP2"))
loss.append(tf.multiply(3e-7, additional_losses[2], name="loss_LT1"))
loss.append(tf.multiply(1e-6, additional_losses[3], name="loss_LT2"))
loss.append(tf.multiply(1e-6, additional_losses[4], name="loss_LT3"))
self.g_loss = self.g_loss + tf.add_n(loss, name='total_g_loss')
add_moving_summary(self.g_loss, *loss)
self.g_loss = tf.add_n(loss, name='total_g_loss')
self.d_loss = tf.multiply(self.d_loss, GAN_FACTOR_PARAMETER, name='d_loss')
add_moving_summary(self.g_loss, self.d_loss, *loss)
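Spelled out, my reading of the new generator objective (L_A is the adversarial g_loss; weights per table 2 of the paper's appendix):

L_G = 2.0 * L_A + 2e-1 * L_P1 + 2e-2 * L_P2 + 3e-7 * L_T1 + 1e-6 * L_T2 + 1e-6 * L_T3

so total_g_loss now replaces g_loss rather than being added to it, and d_loss is scaled by the same GAN_FACTOR_PARAMETER = 2.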
# visualization
viz = (tf.concat([Ibicubic, fake_hr, real_hr], 2)) * 255. + VGG_MEAN_TENSOR
viz = (tf.concat([Ibicubic, fake_hr, real_hr], 2)) * 255.
viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
tf.summary.image('input,fake,real', viz,
max_outputs=max(30, BATCH_SIZE))
@@ -207,8 +213,7 @@ class Model(GANModelDesc):
lr = tf.get_variable(
'learning_rate', initializer=1e-4, trainable=False)
opt = tf.train.AdamOptimizer(lr)
gradprocs = [gradproc.ScaleGradient([('discrim/*', 0.3)])]
return optimizer.apply_grad_processors(opt, gradprocs)
return opt
def apply(model_path, lowres_path="", output_path='.'):
@@ -217,7 +222,6 @@ def apply(model_path, lowres_path="", output_path='.'):
lr = cv2.imread(lowres_path).astype(np.float32)
baseline = cv2.resize(lr, (0, 0), fx=4, fy=4, interpolation=cv2.INTER_CUBIC)
LR_SIZE_H, LR_SIZE_W = lr.shape[:2]
lr -= VGG_MEAN
predict_func = OfflinePredictor(PredictConfig(
model=Model(LR_SIZE_H, LR_SIZE_W),
@@ -226,7 +230,7 @@ def apply(model_path, lowres_path="", output_path='.'):
output_names=['prediction']))
pred = predict_func(lr[None, ...])
p = np.clip(pred[0][0, ...] * 255, 0, 255)
p = np.clip(pred[0][0, ...], 0, 255)
cv2.imwrite(os.path.join(output_path, "predition.png"), p)
cv2.imwrite(os.path.join(output_path, "baseline.png"), baseline)
@@ -238,8 +242,7 @@ def get_data(lmdb):
augmentors = [imgaug.RandomCrop(128),
imgaug.Flip(horiz=True)]
ds = AugmentImageComponent(ds, augmentors, index=0, copy=True)
ds = MapData(ds, lambda x: x - VGG_MEAN)
ds = MapData(ds, lambda x: [cv2.resize(x[0], (32, 32), interpolation=cv2.INTER_AREA), x[0]])
ds = MapData(ds, lambda x: [cv2.resize(x[0], (32, 32), interpolation=cv2.INTER_CUBIC), x[0]])
ds = PrefetchDataZMQ(ds, 8)
ds = BatchData(ds, BATCH_SIZE)
return ds
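A single-sample sketch of what get_data emits (illustrative only; assumes the lmdb yields decoded high-res images): the random 128x128 crop is the training target, and its 4x bicubic downscale is the 32x32 network input.

import cv2
import numpy as np

hr = np.random.randint(0, 256, (140, 150, 3)).astype(np.float32)
crop = hr[:128, :128]  # stand-in for imgaug.RandomCrop(128)
lr = cv2.resize(crop, (32, 32), interpolation=cv2.INTER_CUBIC)  # 4x downscale
print(lr.shape, crop.shape)  # (32, 32, 3) (128, 128, 3)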
@@ -275,16 +278,14 @@ if __name__ == '__main__':
nr_tower = max(get_nr_gpu(), 1)
data = QueueInput(get_data(args.lmdb))
model = Model()
if nr_tower == 1:
trainer = GANTrainer(data, model)
else:
trainer = MultiGPUGANTrainer(nr_tower, data, model)
trainer = SeparateGANTrainer(data, model, d_period=3)
trainer.train_with_defaults(
callbacks=[
ModelSaver(keep_checkpoint_every_n_hours=2)
],
session_init=session_init,
steps_per_epoch=data.size(),
steps_per_epoch=data.size() // 4,
max_epoch=2000
)
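As I understand tensorpack's SeparateGANTrainer, d_period=3 makes the discriminator op run only on every third global step while the generator op runs every step; a hypothetical sketch of that schedule (not the trainer's actual code):

num_steps = 9
for step in range(num_steps):
    if step % 3 == 0:
        print(step, 'update D')  # discriminator: once every 3 steps
    print(step, 'update G')      # generator: every step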