Commit bf57d0af authored by Yuxin Wu

import namespace cleanup in tfutils

parent 23e4f928
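
For context, the example scripts touched below all move from wildcard imports to explicit tfutils submodule imports. A minimal sketch of the pattern (not taken from the repository; the names come from the hunks in this commit):

# Before: star-import dumps symbols such as get_scalar_var into the script's namespace.
# from tensorpack.tfutils.symbolic_functions import *
# lr = get_scalar_var('learning_rate', 3e-5, summary=True)

# After: import the submodules explicitly and keep the names qualified.
from tensorpack.tfutils import optimizer, summary
import tensorpack.tfutils.symbolic_functions as symbf

lr = symbf.get_scalar_var('learning_rate', 3e-5, summary=True)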
@@ -14,6 +14,7 @@ from tensorpack import *
 from tensorpack.utils.viz import *
 from tensorpack.tfutils.distributions import *
 from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
+from tensorpack.tfutils import optimizer, summary
 import tensorpack.tfutils.symbolic_functions as symbf
 from tensorpack.tfutils.gradproc import ScaleGradient, CheckGradient
 from GAN import GANTrainer, GANModelDesc
...
@@ -7,6 +7,7 @@ import os
 import argparse

 from tensorpack import *
+from tensorpack.tfutils import optimizer
 from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.utils.globvars import globalns as G
 import tensorflow as tf
...
@@ -12,7 +12,8 @@ import os
 import sys

 from tensorpack import *
-from tensorpack.tfutils.symbolic_functions import *
+import tensorpack.tfutils.symbolic_functions as symbf
+from tensorpack.tfutils import optimizer
 from tensorpack.tfutils.summary import *
@@ -72,7 +73,7 @@ class Model(ModelDesc):
         costs = []
         for idx, b in enumerate([b1, b2, b3, b4, b5, final_map]):
             output = tf.nn.sigmoid(b, name='output{}'.format(idx + 1))
-            xentropy = class_balanced_sigmoid_cross_entropy(
+            xentropy = symbf.class_balanced_sigmoid_cross_entropy(
                 b, edgemap,
                 name='xentropy{}'.format(idx + 1))
             costs.append(xentropy)
@@ -93,7 +94,7 @@ class Model(ModelDesc):
         add_moving_summary(costs + [wrong, self.cost])

     def _get_optimizer(self):
-        lr = get_scalar_var('learning_rate', 3e-5, summary=True)
+        lr = symbf.get_scalar_var('learning_rate', 3e-5, summary=True)
         opt = tf.train.AdamOptimizer(lr, epsilon=1e-3)
         return optimizer.apply_grad_processors(
             opt, [gradproc.ScaleGradient(
...
@@ -9,6 +9,7 @@ import argparse

 from tensorpack import *
 from tensorpack.tfutils.gradproc import *
+from tensorpack.tfutils import optimizer, summary
 from tensorpack.utils import logger
 from tensorpack.utils.fs import download, get_dataset_path
 from tensorpack.utils.argtools import memoized_ignoreargs
...
@@ -12,6 +12,7 @@ import multiprocessing
 import tensorflow as tf
 from tensorflow.contrib.layers import variance_scaling_initializer

 from tensorpack import *
+from tensorpack.tfutils import optimizer
 from tensorpack.tfutils.symbolic_functions import *
 from tensorpack.tfutils.summary import *
...
@@ -11,6 +11,7 @@ import sys
 import argparse

 from tensorpack import *
+from tensorpack.tfutils import sesscreate, optimizer, summary
 import tensorpack.tfutils.symbolic_functions as symbf

 IMAGE_SIZE = 42
@@ -78,7 +79,7 @@ class Model(ModelDesc):
         cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')

-        wrong = symbolic_functions.prediction_incorrect(logits, label)
+        wrong = symbf.prediction_incorrect(logits, label)
         summary.add_moving_summary(tf.reduce_mean(wrong, name='train_error'))

         wd_cost = tf.multiply(1e-5, regularize_cost('fc.*/W', tf.nn.l2_loss),
...
@@ -13,6 +13,7 @@ about 0.6% validation error after 30 epochs.
 # Just import everything into current namespace
 from tensorpack import *
+from tensorpack.tfutils import summary
 import tensorflow as tf
 import tensorpack.tfutils.symbolic_functions as symbf
...
@@ -8,7 +8,7 @@ import numpy as np
 import os

 from tensorpack import *
-from tensorpack.tfutils.symbolic_functions import *
+from tensorpack.tfutils.symbolic_functions import prediction_incorrect
 from tensorpack.tfutils.summary import *
 import tensorflow as tf
...
@@ -33,6 +33,5 @@ for _, module_name, _ in iter_modules(
         continue
     if module_name in _TO_IMPORT:
         _global_import(module_name)   # import the content to tfutils.*
-    else:
-        __all__.append(module_name)   # import the module separately
-__all__.extend(['sessinit', 'gradproc'])
+__all__.extend(['sessinit', 'summary', 'optimizer',
+                'sesscreate', 'gradproc', 'varreplace'])
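
With the submodule names listed in __all__ explicitly, `from tensorpack.tfutils import *` exposes exactly those modules, which is why the example scripts above now import optimizer, summary, etc. by name. A hedged illustration of the mechanism with a made-up package (not tensorpack code):

# mypkg/__init__.py  (hypothetical package)
__all__ = ['optimizer', 'summary']        # only these submodules are star-exported

# user code
from mypkg import *                       # binds mypkg.optimizer and mypkg.summary only
from mypkg import optimizer, summary      # the explicit form the examples switch to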
@@ -6,7 +6,7 @@ import tensorflow as tf
 from contextlib import contextmanager
 import numpy as np

-__all__ = []
+# __all__ = ['get_scalar_var']

 # this function exists for backwards-compatibilty
...
@@ -7,7 +7,6 @@ import six
 import os
 import pprint
 import tensorflow as tf
-from collections import defaultdict
 import numpy as np
 from ..utils import logger
 from .common import get_op_tensor_name
@@ -47,9 +46,7 @@ class SessionUpdate(object):
             vars_to_update: a collection of variables to update
         """
         self.sess = sess
-        self.name_map = defaultdict(list)
-        for v in vars_to_update:
-            self.name_map[v.name].append(v)
+        self.name_map = {v.name: v for v in vars_to_update}

     @staticmethod
     def load_value_to_var(var, val, strict=False):
@@ -108,8 +105,8 @@ class SessionUpdate(object):
         with self.sess.as_default():
             for name, value in six.iteritems(prms):
                 assert name in self.name_map
-                for v in self.name_map[name]:
-                    SessionUpdate.load_value_to_var(v, value)
+                v = self.name_map[name]
+                SessionUpdate.load_value_to_var(v, value)


 def dump_session_params(path):
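
The two hunks above replace the defaultdict-of-lists with a one-to-one {name: variable} map. A rough TF1-style sketch of how such a map pushes values into variables (variable name and values are invented; load_value_to_var itself does extra shape/dtype checking):

import tensorflow as tf

w = tf.get_variable('w', shape=[2], dtype=tf.float32)
name_map = {v.name: v for v in [w]}        # {'w:0': <tf.Variable 'w:0'>}

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    v = name_map['w:0']
    v.load([1.0, 2.0], session=sess)       # roughly what load_value_to_var does
    print(sess.run(v))                     # [1. 2.]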
@@ -168,6 +165,9 @@ def dump_chkpt_vars(model_path):
     Args:
         model_path(str): path to a checkpoint.
+
+    Returns:
+        dict: a name:value dict
     """
     model_path = get_checkpoint_path(model_path)
     reader = tf.train.NewCheckpointReader(model_path)
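
The docstring now states that dump_chkpt_vars returns a name:value dict. A sketch consistent with that, using the same checkpoint reader as the function body (the checkpoint path is a placeholder):

import tensorflow as tf

reader = tf.train.NewCheckpointReader('/path/to/model-10000')   # placeholder path
chkpt_vars = {name: reader.get_tensor(name)
              for name in reader.get_variable_to_shape_map()}
# chkpt_vars maps each variable name in the checkpoint to a numpy array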
@@ -199,6 +199,6 @@ def is_training_name(name):
         return True
     if name.endswith('/Adagrad'):
         return True
-    if 'EMA_summary/' in name:
+    if name.startswith('/EMA'):
         return True
     return False
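
For reference, a hedged usage sketch of is_training_name as a name filter (the variable names below are invented; only the optimizer slot name is expected to be flagged):

from tensorpack.tfutils.varmanip import is_training_name

names = ['conv0/W', 'conv0/W/Adagrad', 'fc0/b']
model_names = [n for n in names if not is_training_name(n)]
# -> ['conv0/W', 'fc0/b']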