Commit bf57d0af authored by Yuxin Wu's avatar Yuxin Wu

import namespace cleanup in tfutils

parent 23e4f928
......@@ -14,6 +14,7 @@ from tensorpack import *
from tensorpack.utils.viz import *
from tensorpack.tfutils.distributions import *
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.tfutils import optimizer, summary
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.gradproc import ScaleGradient, CheckGradient
from GAN import GANTrainer, GANModelDesc
......
......@@ -7,6 +7,7 @@ import os
import argparse
from tensorpack import *
from tensorpack.tfutils import optimizer
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.globvars import globalns as G
import tensorflow as tf
......
......@@ -12,7 +12,8 @@ import os
import sys
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils import optimizer
from tensorpack.tfutils.summary import *
......@@ -72,7 +73,7 @@ class Model(ModelDesc):
costs = []
for idx, b in enumerate([b1, b2, b3, b4, b5, final_map]):
output = tf.nn.sigmoid(b, name='output{}'.format(idx + 1))
xentropy = class_balanced_sigmoid_cross_entropy(
xentropy = symbf.class_balanced_sigmoid_cross_entropy(
b, edgemap,
name='xentropy{}'.format(idx + 1))
costs.append(xentropy)
......@@ -93,7 +94,7 @@ class Model(ModelDesc):
add_moving_summary(costs + [wrong, self.cost])
def _get_optimizer(self):
lr = get_scalar_var('learning_rate', 3e-5, summary=True)
lr = symbf.get_scalar_var('learning_rate', 3e-5, summary=True)
opt = tf.train.AdamOptimizer(lr, epsilon=1e-3)
return optimizer.apply_grad_processors(
opt, [gradproc.ScaleGradient(
......
......@@ -9,6 +9,7 @@ import argparse
from tensorpack import *
from tensorpack.tfutils.gradproc import *
from tensorpack.tfutils import optimizer, summary
from tensorpack.utils import logger
from tensorpack.utils.fs import download, get_dataset_path
from tensorpack.utils.argtools import memoized_ignoreargs
......
......@@ -12,6 +12,7 @@ import multiprocessing
import tensorflow as tf
from tensorflow.contrib.layers import variance_scaling_initializer
from tensorpack import *
from tensorpack.tfutils import optimizer
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
......
......@@ -11,6 +11,7 @@ import sys
import argparse
from tensorpack import *
from tensorpack.tfutils import sesscreate, optimizer, summary
import tensorpack.tfutils.symbolic_functions as symbf
IMAGE_SIZE = 42
......@@ -78,7 +79,7 @@ class Model(ModelDesc):
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
wrong = symbolic_functions.prediction_incorrect(logits, label)
wrong = symbf.prediction_incorrect(logits, label)
summary.add_moving_summary(tf.reduce_mean(wrong, name='train_error'))
wd_cost = tf.multiply(1e-5, regularize_cost('fc.*/W', tf.nn.l2_loss),
......
......@@ -13,6 +13,7 @@ about 0.6% validation error after 30 epochs.
# Just import everything into current namespace
from tensorpack import *
from tensorpack.tfutils import summary
import tensorflow as tf
import tensorpack.tfutils.symbolic_functions as symbf
......
......@@ -8,7 +8,7 @@ import numpy as np
import os
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.symbolic_functions import prediction_incorrect
from tensorpack.tfutils.summary import *
import tensorflow as tf
......
......@@ -33,6 +33,5 @@ for _, module_name, _ in iter_modules(
continue
if module_name in _TO_IMPORT:
_global_import(module_name) # import the content to tfutils.*
else:
__all__.append(module_name) # import the module separately
__all__.extend(['sessinit', 'gradproc'])
__all__.extend(['sessinit', 'summary', 'optimizer',
'sesscreate', 'gradproc', 'varreplace'])
......@@ -6,7 +6,7 @@ import tensorflow as tf
from contextlib import contextmanager
import numpy as np
__all__ = []
# __all__ = ['get_scalar_var']
# this function exists for backwards-compatibilty
......
......@@ -7,7 +7,6 @@ import six
import os
import pprint
import tensorflow as tf
from collections import defaultdict
import numpy as np
from ..utils import logger
from .common import get_op_tensor_name
......@@ -47,9 +46,7 @@ class SessionUpdate(object):
vars_to_update: a collection of variables to update
"""
self.sess = sess
self.name_map = defaultdict(list)
for v in vars_to_update:
self.name_map[v.name].append(v)
self.name_map = {v.name: v for v in vars_to_update}
@staticmethod
def load_value_to_var(var, val, strict=False):
......@@ -108,7 +105,7 @@ class SessionUpdate(object):
with self.sess.as_default():
for name, value in six.iteritems(prms):
assert name in self.name_map
for v in self.name_map[name]:
v = self.name_map[name]
SessionUpdate.load_value_to_var(v, value)
......@@ -168,6 +165,9 @@ def dump_chkpt_vars(model_path):
Args:
model_path(str): path to a checkpoint.
Returns:
dict: a name:value dict
"""
model_path = get_checkpoint_path(model_path)
reader = tf.train.NewCheckpointReader(model_path)
......@@ -199,6 +199,6 @@ def is_training_name(name):
return True
if name.endswith('/Adagrad'):
return True
if 'EMA_summary/' in name:
if name.startswith('/EMA'):
return True
return False
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment