Commit 8a3d0d60 authored by Yuxin Wu's avatar Yuxin Wu

small update

parent 1b06a41a
......@@ -90,7 +90,7 @@ class Model(ModelDesc):
.Conv2D('conv3', out_channel=64, kernel_shape=3)
# the original arch is 2x faster
# .Conv2D('conv0', image, out_channel=32, kernel_shape=8, stride=4)
# .Conv2D('conv0', out_channel=32, kernel_shape=8, stride=4)
# .Conv2D('conv1', out_channel=64, kernel_shape=4, stride=2)
# .Conv2D('conv2', out_channel=64, kernel_shape=3)
......
......@@ -42,6 +42,11 @@ LAMBDA = 100
NF = 64 # number of filter
def BNLReLU(x, name=None):
x = BatchNorm('bn', x)
return LeakyReLU(x, name=name)
class Model(GANModelDesc):
def _get_inputs(self):
SHAPE = 256
......@@ -54,8 +59,7 @@ class Model(GANModelDesc):
with argscope(BatchNorm, use_local_stat=True), \
argscope(Dropout, is_training=True):
# always use local stat for BN, and apply dropout even in testing
with argscope(Conv2D, kernel_shape=4, stride=2,
nl=lambda x, name: LeakyReLU(BatchNorm('bn', x), name=name)):
with argscope(Conv2D, kernel_shape=4, stride=2, nl=BNLReLU):
e1 = Conv2D('conv1', imgs, NF, nl=LeakyReLU)
e2 = Conv2D('conv2', e1, NF * 2)
e3 = Conv2D('conv3', e2, NF * 4)
......@@ -89,16 +93,13 @@ class Model(GANModelDesc):
def discriminator(self, inputs, outputs):
""" return a (b, 1) logits"""
l = tf.concat([inputs, outputs], 3)
with argscope(Conv2D, nl=tf.identity, kernel_shape=4, stride=2):
with argscope(Conv2D, kernel_shape=4, stride=2, nl=BNLReLU):
l = (LinearWrap(l)
.Conv2D('conv0', NF, nl=LeakyReLU)
.Conv2D('conv1', NF * 2)
.BatchNorm('bn1').LeakyReLU()
.Conv2D('conv2', NF * 4)
.BatchNorm('bn2').LeakyReLU()
.Conv2D('conv3', NF * 8, stride=1, padding='VALID')
.BatchNorm('bn3').LeakyReLU()
.Conv2D('convlast', 1, stride=1, padding='VALID')())
.Conv2D('convlast', 1, stride=1, padding='VALID', nl=tf.identity)())
return l
def _build_graph(self, inputs):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment