Commit 491e0cd9 authored by Yuxin Wu's avatar Yuxin Wu

use globalns in DCGAN

parent 2a0e96e0
...@@ -13,6 +13,7 @@ import cv2 ...@@ -13,6 +13,7 @@ import cv2
from tensorpack import * from tensorpack import *
from tensorpack.utils.viz import * from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average
from tensorpack.utils.globvars import globalns as CFG, use_global_argument
import tensorpack.tfutils.symbolic_functions as symbf import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, RandomZData, build_GAN_losses from GAN import GANTrainer, RandomZData, build_GAN_losses
...@@ -27,12 +28,13 @@ The original code (dcgan.torch) uses kernel_shape=4, but I found the difference ...@@ -27,12 +28,13 @@ The original code (dcgan.torch) uses kernel_shape=4, but I found the difference
./DCGAN-CelebA.py --load model.tfmodel --sample ./DCGAN-CelebA.py --load model.tfmodel --sample
""" """
SHAPE = 64 CFG.SHAPE = 64
BATCH = 128 CFG.BATCH = 128
CFG.Z_DIM = 100
class Model(ModelDesc): class Model(ModelDesc):
def _get_input_vars(self): def _get_input_vars(self):
return [InputVar(tf.float32, (None, SHAPE, SHAPE, 3), 'input') ] return [InputVar(tf.float32, (None, CFG.SHAPE, CFG.SHAPE, 3), 'input') ]
def generator(self, z): def generator(self, z):
""" return a image generated from z""" """ return a image generated from z"""
...@@ -66,8 +68,8 @@ class Model(ModelDesc): ...@@ -66,8 +68,8 @@ class Model(ModelDesc):
image_pos = input_vars[0] image_pos = input_vars[0]
image_pos = image_pos / 128.0 - 1 image_pos = image_pos / 128.0 - 1
z = tf.random_uniform([BATCH, 100], -1, 1, name='z_train') z = tf.random_uniform([CFG.BATCH, CFG.Z_DIM], -1, 1, name='z_train')
z = tf.placeholder_with_default(z, [None, 100], name='z') z = tf.placeholder_with_default(z, [None, CFG.Z_DIM], name='z')
with argscope([Conv2D, Deconv2D, FullyConnected], with argscope([Conv2D, Deconv2D, FullyConnected],
W_init=tf.truncated_normal_initializer(stddev=0.02)): W_init=tf.truncated_normal_initializer(stddev=0.02)):
...@@ -85,12 +87,12 @@ class Model(ModelDesc): ...@@ -85,12 +87,12 @@ class Model(ModelDesc):
self.d_vars = [v for v in all_vars if v.name.startswith('discrim/')] self.d_vars = [v for v in all_vars if v.name.startswith('discrim/')]
def get_data(): def get_data():
datadir = args.data datadir = CFG.data
imgs = glob.glob(datadir + '/*.jpg') imgs = glob.glob(datadir + '/*.jpg')
ds = ImageFromFile(imgs, channel=3, shuffle=True) ds = ImageFromFile(imgs, channel=3, shuffle=True)
augs = [ imgaug.CenterCrop(110), imgaug.Resize(64) ] augs = [ imgaug.CenterCrop(140), imgaug.Resize(64) ]
ds = AugmentImageComponent(ds, augs) ds = AugmentImageComponent(ds, augs)
ds = BatchData(ds, BATCH) ds = BatchData(ds, CFG.BATCH)
ds = PrefetchDataZMQ(ds, 1) ds = PrefetchDataZMQ(ds, 1)
return ds return ds
...@@ -149,8 +151,8 @@ if __name__ == '__main__': ...@@ -149,8 +151,8 @@ if __name__ == '__main__':
parser.add_argument('--sample', action='store_true', help='run sampling') parser.add_argument('--sample', action='store_true', help='run sampling')
parser.add_argument('--vec', action='store_true', help='run vec arithmetic demo') parser.add_argument('--vec', action='store_true', help='run vec arithmetic demo')
parser.add_argument('--data', help='`image_align_celeba` directory of the celebA dataset') parser.add_argument('--data', help='`image_align_celeba` directory of the celebA dataset')
global args
args = parser.parse_args() args = parser.parse_args()
use_global_argument(args)
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.sample: if args.sample:
...@@ -162,5 +164,4 @@ if __name__ == '__main__': ...@@ -162,5 +164,4 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
GANTrainer(config).train() GANTrainer(config, g_vs_d=1).train()
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import six import six
import argparse
__all__ = ['globalns'] __all__ = ['globalns', 'use_global_argument']
if six.PY2: if six.PY2:
class NS: pass class NS: pass
...@@ -14,3 +15,12 @@ else: ...@@ -14,3 +15,12 @@ else:
NS = types.SimpleNamespace NS = types.SimpleNamespace
globalns = NS() globalns = NS()
def use_global_argument(args):
"""
Add the content of argparse.Namespace to globalns
:param args: Argument
"""
assert isinstance(args, argparse.Namespace), type(args)
for k, v in six.iteritems(vars(args)):
setattr(globalns, k, v)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment