Commit 66726a8f authored by Yuxin Wu's avatar Yuxin Wu

move use_global_argument into ns

parent 901727fc
...@@ -13,7 +13,6 @@ from tensorpack.utils.viz import * ...@@ -13,7 +13,6 @@ from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.utils.globvars import globalns as opt from tensorpack.utils.globvars import globalns as opt
from tensorpack.utils.globvars import use_global_argument
import tensorflow as tf import tensorflow as tf
from GAN import GANTrainer, RandomZData, GANModelDesc from GAN import GANTrainer, RandomZData, GANModelDesc
...@@ -153,7 +152,7 @@ def get_args(): ...@@ -153,7 +152,7 @@ def get_args():
parser.add_argument('--load-size', help='size to load the original images', type=int) parser.add_argument('--load-size', help='size to load the original images', type=int)
parser.add_argument('--crop-size', help='crop the original images', type=int) parser.add_argument('--crop-size', help='crop the original images', type=int)
args = parser.parse_args() args = parser.parse_args()
use_global_argument(args) opt.use_argument(args)
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
return args return args
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
import six import six
import argparse import argparse
__all__ = ['globalns', 'use_global_argument'] __all__ = ['globalns']
if six.PY2: if six.PY2:
class NS: class NS:
...@@ -15,16 +15,18 @@ else: ...@@ -15,16 +15,18 @@ else:
import types import types
NS = types.SimpleNamespace NS = types.SimpleNamespace
globalns = NS()
class MyNS(NS):
def use_global_argument(args): def use_argument(self, args):
""" """
Add the content of :class:`argparse.Namespace` to globalns. Add the content of :class:`argparse.Namespace` to this ns.
Args: Args:
args (argparse.Namespace): arguments args (argparse.Namespace): arguments
""" """
assert isinstance(args, argparse.Namespace), type(args) assert isinstance(args, argparse.Namespace), type(args)
for k, v in six.iteritems(vars(args)): for k, v in six.iteritems(vars(args)):
setattr(globalns, k, v) setattr(self, k, v)
globalns = MyNS()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment