Commit adca1cd7 authored by Yuxin Wu's avatar Yuxin Wu

remove the use of globalns in GANs and just use argparse

parent 8d668003
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.gpu import get_nr_gpu from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.utils.globvars import globalns as G
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf import tensorflow as tf
...@@ -21,8 +20,6 @@ A pretrained model on CelebA is at http://models.tensorpack.com/GAN/ ...@@ -21,8 +20,6 @@ A pretrained model on CelebA is at http://models.tensorpack.com/GAN/
import DCGAN import DCGAN
G.BATCH = 32
G.Z_DIM = 64
NH = 64 NH = 64
NF = 64 NF = 64
GAMMA = 0.5 GAMMA = 0.5
...@@ -30,7 +27,7 @@ GAMMA = 0.5 ...@@ -30,7 +27,7 @@ GAMMA = 0.5
class Model(GANModelDesc): class Model(GANModelDesc):
def _get_inputs(self): def _get_inputs(self):
return [InputDesc(tf.float32, (None, G.SHAPE, G.SHAPE, 3), 'input')] return [InputDesc(tf.float32, (None, args.final_size, args.final_size, 3), 'input')]
@auto_reuse_variable_scope @auto_reuse_variable_scope
def decoder(self, z): def decoder(self, z):
...@@ -80,8 +77,8 @@ class Model(GANModelDesc): ...@@ -80,8 +77,8 @@ class Model(GANModelDesc):
image_pos = inputs[0] image_pos = inputs[0]
image_pos = image_pos / 128.0 - 1 image_pos = image_pos / 128.0 - 1
z = tf.random_uniform([G.BATCH, G.Z_DIM], minval=-1, maxval=1, name='z_train') z = tf.random_uniform([args.batch, args.z_dim], minval=-1, maxval=1, name='z_train')
z = tf.placeholder_with_default(z, [None, G.Z_DIM], name='z') z = tf.placeholder_with_default(z, [None, args.z_dim], name='z')
def summary_image(name, x): def summary_image(name, x):
x = (x + 1.0) * 128.0 x = (x + 1.0) * 128.0
...@@ -133,14 +130,13 @@ class Model(GANModelDesc): ...@@ -133,14 +130,13 @@ class Model(GANModelDesc):
if __name__ == '__main__': if __name__ == '__main__':
args = DCGAN.get_args() args = DCGAN.get_args(default_batch=32, default_z_dim=64)
if args.sample: if args.sample:
DCGAN.sample(Model(), args.load, 'gen/conv4.3/output') DCGAN.sample(Model(), args.load, 'gen/conv4.3/output')
else: else:
assert args.data
logger.auto_set_dir() logger.auto_set_dir()
input = QueueInput(DCGAN.get_data(args.data)) input = QueueInput(DCGAN.get_data())
model = Model() model = Model()
nr_tower = max(get_nr_gpu(), 1) nr_tower = max(get_nr_gpu(), 1)
if nr_tower == 1: if nr_tower == 1:
......
...@@ -12,7 +12,6 @@ import argparse ...@@ -12,7 +12,6 @@ import argparse
from tensorpack import * from tensorpack import *
from tensorpack.utils.viz import stack_patches from tensorpack.utils.viz import stack_patches
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.utils.globvars import globalns as opt
import tensorflow as tf import tensorflow as tf
from GAN import GANTrainer, RandomZData, GANModelDesc from GAN import GANTrainer, RandomZData, GANModelDesc
...@@ -34,15 +33,15 @@ You can also train on other images (just use any directory of jpg files in ...@@ -34,15 +33,15 @@ You can also train on other images (just use any directory of jpg files in
A pretrained model on CelebA is at http://models.tensorpack.com/GAN/ A pretrained model on CelebA is at http://models.tensorpack.com/GAN/
""" """
# global vars
opt.SHAPE = 64
opt.BATCH = 128
opt.Z_DIM = 100
class Model(GANModelDesc): class Model(GANModelDesc):
def __init__(self, shape, batch, z_dim):
self.shape = shape
self.batch = batch
self.zdim = z_dim
def _get_inputs(self): def _get_inputs(self):
return [InputDesc(tf.float32, (None, opt.SHAPE, opt.SHAPE, 3), 'input')] return [InputDesc(tf.float32, (None, self.shape, self.shape, 3), 'input')]
def generator(self, z): def generator(self, z):
""" return an image generated from z""" """ return an image generated from z"""
...@@ -81,8 +80,8 @@ class Model(GANModelDesc): ...@@ -81,8 +80,8 @@ class Model(GANModelDesc):
image_pos = inputs[0] image_pos = inputs[0]
image_pos = image_pos / 128.0 - 1 image_pos = image_pos / 128.0 - 1
z = tf.random_uniform([opt.BATCH, opt.Z_DIM], -1, 1, name='z_train') z = tf.random_uniform([self.batch, self.zdim], -1, 1, name='z_train')
z = tf.placeholder_with_default(z, [None, opt.Z_DIM], name='z') z = tf.placeholder_with_default(z, [None, self.zdim], name='z')
with argscope([Conv2D, Deconv2D, FullyConnected], with argscope([Conv2D, Deconv2D, FullyConnected],
W_init=tf.truncated_normal_initializer(stddev=0.02)): W_init=tf.truncated_normal_initializer(stddev=0.02)):
...@@ -103,19 +102,20 @@ class Model(GANModelDesc): ...@@ -103,19 +102,20 @@ class Model(GANModelDesc):
def get_augmentors(): def get_augmentors():
augs = [] augs = []
if opt.load_size: if args.load_size:
augs.append(imgaug.Resize(opt.load_size)) augs.append(imgaug.Resize(args.load_size))
if opt.crop_size: if args.crop_size:
augs.append(imgaug.CenterCrop(opt.crop_size)) augs.append(imgaug.CenterCrop(args.crop_size))
augs.append(imgaug.Resize(opt.SHAPE)) augs.append(imgaug.Resize(args.final_size))
return augs return augs
def get_data(datadir): def get_data():
imgs = glob.glob(datadir + '/*.jpg') assert args.data
imgs = glob.glob(args.data + '/*.jpg')
ds = ImageFromFile(imgs, channel=3, shuffle=True) ds = ImageFromFile(imgs, channel=3, shuffle=True)
ds = AugmentImageComponent(ds, get_augmentors()) ds = AugmentImageComponent(ds, get_augmentors())
ds = BatchData(ds, opt.BATCH) ds = BatchData(ds, args.batch)
ds = PrefetchDataZMQ(ds, 5) ds = PrefetchDataZMQ(ds, 5)
return ds return ds
...@@ -126,7 +126,7 @@ def sample(model, model_path, output_name='gen/gen'): ...@@ -126,7 +126,7 @@ def sample(model, model_path, output_name='gen/gen'):
model=model, model=model,
input_names=['z'], input_names=['z'],
output_names=[output_name, 'z']) output_names=[output_name, 'z'])
pred = SimpleDatasetPredictor(pred, RandomZData((100, opt.Z_DIM))) pred = SimpleDatasetPredictor(pred, RandomZData((100, args.z_dim)))
for o in pred.get_result(): for o in pred.get_result():
o = o[0] + 1 o = o[0] + 1
o = o * 128.0 o = o * 128.0
...@@ -135,7 +135,7 @@ def sample(model, model_path, output_name='gen/gen'): ...@@ -135,7 +135,7 @@ def sample(model, model_path, output_name='gen/gen'):
stack_patches(o, nr_row=10, nr_col=10, viz=True) stack_patches(o, nr_row=10, nr_col=10, viz=True)
def get_args(): def get_args(default_batch=128, default_z_dim=100):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
...@@ -143,8 +143,13 @@ def get_args(): ...@@ -143,8 +143,13 @@ def get_args():
parser.add_argument('--data', help='a jpeg directory') parser.add_argument('--data', help='a jpeg directory')
parser.add_argument('--load-size', help='size to load the original images', type=int) parser.add_argument('--load-size', help='size to load the original images', type=int)
parser.add_argument('--crop-size', help='crop the original images', type=int) parser.add_argument('--crop-size', help='crop the original images', type=int)
parser.add_argument(
'--final-size', default=64, type=int,
help='resize to this shape as inputs to network')
parser.add_argument('--z-dim', help='hidden dimension', type=int, default=default_z_dim)
parser.add_argument('--batch', help='batch size', type=int, default=default_batch)
global args
args = parser.parse_args() args = parser.parse_args()
opt.use_argument(args)
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
return args return args
...@@ -152,14 +157,14 @@ def get_args(): ...@@ -152,14 +157,14 @@ def get_args():
if __name__ == '__main__': if __name__ == '__main__':
args = get_args() args = get_args()
M = Model(shape=args.final_size, batch=args.batch, z_dim=args.z_dim)
if args.sample: if args.sample:
sample(Model(), args.load) sample(M, args.load)
else: else:
assert args.data
logger.auto_set_dir() logger.auto_set_dir()
GANTrainer( GANTrainer(
input=QueueInput(get_data(args.data)), input=QueueInput(get_data()),
model=Model()).train_with_defaults( model=M).train_with_defaults(
callbacks=[ModelSaver()], callbacks=[ModelSaver()],
steps_per_epoch=300, steps_per_epoch=300,
max_epoch=200, max_epoch=200,
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.globvars import globalns as G
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf import tensorflow as tf
...@@ -17,10 +16,8 @@ See the docstring in DCGAN.py for usage. ...@@ -17,10 +16,8 @@ See the docstring in DCGAN.py for usage.
""" """
# Don't want to mix two examples together, but want to reuse the code. # Don't want to mix two examples together, but want to reuse the code.
# So here just import stuff from DCGAN, and change the batch size & model # So here just import stuff from DCGAN.
import DCGAN import DCGAN
G.BATCH = 64
G.Z_DIM = 128
class Model(DCGAN.Model): class Model(DCGAN.Model):
...@@ -47,8 +44,8 @@ class Model(DCGAN.Model): ...@@ -47,8 +44,8 @@ class Model(DCGAN.Model):
image_pos = inputs[0] image_pos = inputs[0]
image_pos = image_pos / 128.0 - 1 image_pos = image_pos / 128.0 - 1
z = tf.random_normal([G.BATCH, G.Z_DIM], name='z_train') z = tf.random_normal([self.batch, self.zdim], name='z_train')
z = tf.placeholder_with_default(z, [None, G.Z_DIM], name='z') z = tf.placeholder_with_default(z, [None, self.zdim], name='z')
with argscope([Conv2D, Deconv2D, FullyConnected], with argscope([Conv2D, Deconv2D, FullyConnected],
W_init=tf.truncated_normal_initializer(stddev=0.02)): W_init=tf.truncated_normal_initializer(stddev=0.02)):
...@@ -56,7 +53,7 @@ class Model(DCGAN.Model): ...@@ -56,7 +53,7 @@ class Model(DCGAN.Model):
image_gen = self.generator(z) image_gen = self.generator(z)
tf.summary.image('generated-samples', image_gen, max_outputs=30) tf.summary.image('generated-samples', image_gen, max_outputs=30)
alpha = tf.random_uniform(shape=[G.BATCH, 1, 1, 1], alpha = tf.random_uniform(shape=[self.batch, 1, 1, 1],
minval=0., maxval=1., name='alpha') minval=0., maxval=1., name='alpha')
interp = image_pos + alpha * (image_gen - image_pos) interp = image_pos + alpha * (image_gen - image_pos)
...@@ -86,15 +83,15 @@ class Model(DCGAN.Model): ...@@ -86,15 +83,15 @@ class Model(DCGAN.Model):
if __name__ == '__main__': if __name__ == '__main__':
args = DCGAN.get_args() args = DCGAN.get_args(default_batch=64, default_z_dim=128)
M = Model(shape=args.final_size, batch=args.batch, z_dim=args.z_dim)
if args.sample: if args.sample:
DCGAN.sample(Model(), args.load) DCGAN.sample(M, args.load)
else: else:
assert args.data
logger.auto_set_dir() logger.auto_set_dir()
SeparateGANTrainer( SeparateGANTrainer(
QueueInput(DCGAN.get_data(args.data)), QueueInput(DCGAN.get_data()),
Model(), g_period=6).train_with_defaults( M, g_period=6).train_with_defaults(
callbacks=[ModelSaver()], callbacks=[ModelSaver()],
steps_per_epoch=300, steps_per_epoch=300,
max_epoch=200, max_epoch=200,
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.globvars import globalns as G
import tensorflow as tf import tensorflow as tf
from GAN import SeparateGANTrainer from GAN import SeparateGANTrainer
...@@ -15,9 +14,8 @@ See the docstring in DCGAN.py for usage. ...@@ -15,9 +14,8 @@ See the docstring in DCGAN.py for usage.
""" """
# Don't want to mix two examples together, but want to reuse the code. # Don't want to mix two examples together, but want to reuse the code.
# So here just import stuff from DCGAN, and change the batch size & model # So here just import stuff from DCGAN
import DCGAN import DCGAN
G.BATCH = 64
class Model(DCGAN.Model): class Model(DCGAN.Model):
...@@ -64,20 +62,19 @@ class ClipCallback(Callback): ...@@ -64,20 +62,19 @@ class ClipCallback(Callback):
if __name__ == '__main__': if __name__ == '__main__':
args = DCGAN.get_args() args = DCGAN.get_args(default_batch=64)
M = Model(shape=args.final_size, batch=args.batch, z_dim=args.z_dim)
if args.sample: if args.sample:
DCGAN.sample(Model(), args.load) DCGAN.sample(M, args.load)
else: else:
assert args.data
logger.auto_set_dir() logger.auto_set_dir()
# The original code uses a different schedule, but this seems to work well. # The original code uses a different schedule, but this seems to work well.
# Train 1 D after 2 G # Train 1 D after 2 G
SeparateGANTrainer( SeparateGANTrainer(
input=QueueInput(DCGAN.get_data(args.data)), input=QueueInput(DCGAN.get_data()),
model=Model(), model=M, d_period=3).train_with_defaults(
d_period=3).train_with_defaults(
callbacks=[ModelSaver(), ClipCallback()], callbacks=[ModelSaver(), ClipCallback()],
steps_per_epoch=500, steps_per_epoch=500,
max_epoch=200, max_epoch=200,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment