Commit 491e0cd9 authored by Yuxin Wu's avatar Yuxin Wu

use globalns in DCGAN

parent 2a0e96e0
......@@ -13,6 +13,7 @@ import cv2
from tensorpack import *
from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average
from tensorpack.utils.globvars import globalns as CFG, use_global_argument
import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, RandomZData, build_GAN_losses
......@@ -27,12 +28,13 @@ The original code (dcgan.torch) uses kernel_shape=4, but I found the difference
./DCGAN-CelebA.py --load model.tfmodel --sample
"""
SHAPE = 64
BATCH = 128
CFG.SHAPE = 64
CFG.BATCH = 128
CFG.Z_DIM = 100
class Model(ModelDesc):
def _get_input_vars(self):
return [InputVar(tf.float32, (None, SHAPE, SHAPE, 3), 'input') ]
return [InputVar(tf.float32, (None, CFG.SHAPE, CFG.SHAPE, 3), 'input') ]
def generator(self, z):
""" return a image generated from z"""
......@@ -66,8 +68,8 @@ class Model(ModelDesc):
image_pos = input_vars[0]
image_pos = image_pos / 128.0 - 1
z = tf.random_uniform([BATCH, 100], -1, 1, name='z_train')
z = tf.placeholder_with_default(z, [None, 100], name='z')
z = tf.random_uniform([CFG.BATCH, CFG.Z_DIM], -1, 1, name='z_train')
z = tf.placeholder_with_default(z, [None, CFG.Z_DIM], name='z')
with argscope([Conv2D, Deconv2D, FullyConnected],
W_init=tf.truncated_normal_initializer(stddev=0.02)):
......@@ -85,12 +87,12 @@ class Model(ModelDesc):
self.d_vars = [v for v in all_vars if v.name.startswith('discrim/')]
def get_data():
datadir = args.data
datadir = CFG.data
imgs = glob.glob(datadir + '/*.jpg')
ds = ImageFromFile(imgs, channel=3, shuffle=True)
augs = [ imgaug.CenterCrop(110), imgaug.Resize(64) ]
augs = [ imgaug.CenterCrop(140), imgaug.Resize(64) ]
ds = AugmentImageComponent(ds, augs)
ds = BatchData(ds, BATCH)
ds = BatchData(ds, CFG.BATCH)
ds = PrefetchDataZMQ(ds, 1)
return ds
......@@ -149,8 +151,8 @@ if __name__ == '__main__':
parser.add_argument('--sample', action='store_true', help='run sampling')
parser.add_argument('--vec', action='store_true', help='run vec arithmetic demo')
parser.add_argument('--data', help='`image_align_celeba` directory of the celebA dataset')
global args
args = parser.parse_args()
use_global_argument(args)
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.sample:
......@@ -162,5 +164,4 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
GANTrainer(config).train()
GANTrainer(config, g_vs_d=1).train()
......@@ -4,8 +4,9 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import six
import argparse
__all__ = ['globalns']
__all__ = ['globalns', 'use_global_argument']
if six.PY2:
class NS: pass
......@@ -14,3 +15,12 @@ else:
NS = types.SimpleNamespace
globalns = NS()
def use_global_argument(args):
"""
Add the content of argparse.Namespace to globalns
:param args: Argument
"""
assert isinstance(args, argparse.Namespace), type(args)
for k, v in six.iteritems(vars(args)):
setattr(globalns, k, v)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment