Commit 06ea1c0a authored by Yuxin Wu

api docs for tfutils/

parent bbf41d9e
@@ -67,8 +67,7 @@ extensions = [
     'sphinx.ext.autodoc',
     'sphinx.ext.napoleon',
     #'sphinx.ext.coverage',
-    #'sphinx.ext.mathjax',
-    'sphinx.ext.mathbase',
+    'sphinx.ext.mathjax',
     'sphinx.ext.intersphinx',
     'sphinx.ext.viewcode',
 ]
...
@@ -132,8 +132,8 @@ class Model(ModelDesc):
         target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)
-        self.cost = tf.truediv(symbf.huber_loss(target - pred_action_value),
-                               tf.cast(BATCH_SIZE, tf.float32), name='cost')
+        self.cost = tf.reduce_mean(symbf.huber_loss(
+            target - pred_action_value), name='cost')
         summary.add_param_summary(('conv.*/W', ['histogram', 'rms']),
                                   ('fc.*/W', ['histogram', 'rms']))   # monitor all W
         add_moving_summary(self.cost)
...
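The cost switches to tf.reduce_mean because huber_loss (see the symbolic_functions hunk further down) no longer does a reduce_sum internally; the old code divided the summed loss by BATCH_SIZE instead. A small NumPy sketch of why the two forms agree, with hypothetical values:

    import numpy as np

    def elementwise_huber(x, delta=1.0):
        # per-element Huber loss, mirroring the new symbf.huber_loss (no reduction)
        return np.where(np.abs(x) < delta, 0.5 * x ** 2, delta * np.abs(x) - 0.5 * delta ** 2)

    diff = np.array([0.3, -2.0, 1.5, 0.1])           # hypothetical TD errors for a batch of 4
    old_cost = elementwise_huber(diff).sum() / 4.0   # old: reduce_sum inside huber_loss, then / BATCH_SIZE
    new_cost = elementwise_huber(diff).mean()        # new: reduce_mean of the element-wise loss
    assert np.isclose(old_cost, new_cost)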
@@ -88,7 +88,7 @@ class Model(ModelDesc):
     def get_gradient_processor(self):
         return [MapGradient(lambda grad: tf.clip_by_global_norm([grad], 5)[0][0]),
-                ScaleGradient([('STN.*', 0.1)]), SummaryGradient()]
+                ScaleGradient(('STN.*', 0.1)), SummaryGradient()]
 def get_data(isTrain):
...
@@ -34,5 +34,4 @@ for _, module_name, _ in walk_packages(
         continue
     if module_name in _TO_IMPORT:
         _global_import(module_name)
-    if module_name != 'common':
-        __all__.append(module_name)
+__all__.extend(['sessinit', 'gradproc'])
@@ -14,13 +14,30 @@ _ArgScopeStack = []
 @contextmanager
-def argscope(layers, **param):
+def argscope(layers, **kwargs):
+    """
+    Args:
+        layers (list or layer): layer or list of layers to apply the arguments.
+
+    Returns:
+        a context where all appearances of these layers will by default have the
+        arguments specified by kwargs.
+
+    Example:
+        .. code-block:: python
+
+            with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32):
+                x = Conv2D('conv0', x)
+                x = Conv2D('conv1', x)
+                x = Conv2D('conv2', x, out_channel=64)  # override argscope
+    """
     if not isinstance(layers, list):
         layers = [layers]

     def _check_args_exist(l):
         args = inspect.getargspec(l).args
-        for k, v in six.iteritems(param):
+        for k, v in six.iteritems(kwargs):
             assert k in args, "No argument {} in {}".format(k, l.__name__)

     for l in layers:
@@ -29,7 +46,7 @@ def argscope(layers, **param):
     new_scope = copy.copy(get_arg_scope())
     for l in layers:
-        new_scope[l.__name__].update(param)
+        new_scope[l.__name__].update(kwargs)
     _ArgScopeStack.append(new_scope)
     yield
     del _ArgScopeStack[-1]
@@ -37,8 +54,10 @@ def argscope(layers, **param):
 def get_arg_scope():
     """
-    :returns: the current argscope.
-    An argscope is a dict of dict: dict[layername] = {arg: val}
+    Returns:
+        dict: the current argscope.
+
+        An argscope is a dict of dict: ``dict[layername] = {arg: val}``
     """
     if len(_ArgScopeStack) > 0:
         return _ArgScopeStack[-1]
...
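For illustration, a minimal sketch of how the argscope stack documented above behaves (import paths and the Conv2D layer are assumed from tensorpack; the printed values are indicative only):

    import tensorflow as tf
    from tensorpack.models import Conv2D
    from tensorpack.tfutils.argscope import argscope, get_arg_scope

    with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32):
        # every Conv2D call inside this block sees these defaults
        print(get_arg_scope()['Conv2D'])   # {'kernel_shape': 3, 'nl': <relu>, 'out_channel': 32}
    print(get_arg_scope())                 # empty (a defaultdict of dicts) outside the context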
@@ -28,8 +28,10 @@ def get_default_sess_config(mem_fraction=0.99):
     Return a better session config to use as default.
     Tensorflow default session config consume too much resources.
-    :param mem_fraction: fraction of memory to use. default to 0.99
-    :returns: a `tf.ConfigProto` object.
+
+    Args:
+        mem_fraction(float): fraction of memory to use.
+    Returns:
+        tf.ConfigProto: the config to use.
     """
     conf = tf.ConfigProto()
     conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
@@ -41,7 +43,11 @@ def get_default_sess_config(mem_fraction=0.99):
 def get_global_step_var():
-    """ :returns: the global_step variable in the current graph. create if not existed"""
+    """
+    Returns:
+        tf.Tensor: the global_step variable in the current graph. Created if
+        it doesn't exist.
+    """
     try:
         return tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     except KeyError:
@@ -56,7 +62,9 @@ def get_global_step_var():
 def get_global_step():
-    """ :returns: global_step value in current graph and session"""
+    """
+    Returns:
+        float: global_step value in the current graph and session"""
     return tf.train.global_step(
         tf.get_default_session(),
         get_global_step_var())
@@ -66,8 +74,10 @@ def get_op_tensor_name(name):
     """
     Tensor name is assumed to be ``op_name + ':0'``
-    :param name: an op or a tensor name
-    :returns: (op_name, tensor_name)
+
+    Args:
+        name(str): name of an op or a tensor
+    Returns:
+        tuple: (op_name, tensor_name)
     """
     if name.endswith(':0'):
         return name[:-2], name
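A quick illustration of the naming convention this helper implements (the names are made up, the import path is assumed, and the non-':0' branch is outside this hunk, so its result is inferred from the docstring):

    from tensorpack.tfutils.common import get_op_tensor_name

    print(get_op_tensor_name('fc0/W'))    # ('fc0/W', 'fc0/W:0')  -- op name in, ':0' appended
    print(get_op_tensor_name('fc0/W:0'))  # ('fc0/W', 'fc0/W:0')  -- tensor name in, ':0' stripped for the op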
@@ -80,7 +90,10 @@ get_op_var_name = get_op_tensor_name
 def get_tensors_by_names(names):
     """
-    Get a list of tensors in the default graph by a list of names
+    Get a list of tensors in the default graph by a list of names.
+
+    Args:
+        names (list): list of names
     """
     ret = []
     G = tf.get_default_graph()
@@ -94,6 +107,12 @@ get_vars_by_names = get_tensors_by_names
 def backup_collection(keys):
+    """
+    Args:
+        keys (list): list of collection keys to backup
+
+    Returns:
+        dict: the backup
+    """
     ret = {}
     for k in keys:
         ret[k] = copy(tf.get_collection(k))
@@ -101,22 +120,45 @@ def backup_collection(keys):
 def restore_collection(backup):
+    """
+    Restore from a collection backup.
+
+    Args:
+        backup (dict):
+    """
     for k, v in six.iteritems(backup):
         del tf.get_collection_ref(k)[:]
         tf.get_collection_ref(k).extend(v)

 def clear_collection(keys):
+    """
+    Clear some collections.
+
+    Args:
+        keys(list): list of collection keys.
+    """
     for k in keys:
         del tf.get_collection_ref(k)[:]

 @contextmanager
 def freeze_collection(keys):
+    """
+    Args:
+        keys(list): list of collection keys to freeze.
+
+    Returns:
+        a context where the collections are restored to their initial state
+        at the end.
+    """
     backup = backup_collection(keys)
     yield
     restore_collection(backup)

 def get_tf_version():
+    """
+    Returns:
+        int: the minor version of tensorflow
+    """
     return int(tf.__version__.split('.')[1])
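For illustration, a small sketch of the collection helpers documented in this hunk (the import path and the collection key are assumptions; freeze_collection simply backs up the collections and restores them when the context exits):

    import tensorflow as tf
    from tensorpack.tfutils.common import backup_collection, restore_collection, freeze_collection

    # temporarily build ops without polluting the UPDATE_OPS collection
    with freeze_collection([tf.GraphKeys.UPDATE_OPS]):
        pass  # anything added to UPDATE_OPS in here is discarded when the context exits

    # the same effect, written out with the two lower-level helpers
    backup = backup_collection([tf.GraphKeys.UPDATE_OPS])
    # ... build some temporary ops ...
    restore_collection(backup)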
@@ -12,16 +12,17 @@ from ..utils import logger
 from .symbolic_functions import rms
 from .summary import add_moving_summary

-__all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
-           'ScaleGradient', 'MapGradient', 'apply_grad_processors',
-           'GlobalNormClip']
+__all__ = ['GradientProcessor', 'GlobalNormClip', 'MapGradient', 'SummaryGradient', 'CheckGradient',
+           'ScaleGradient', 'apply_grad_processors']

 def apply_grad_processors(grads, gradprocs):
     """
-    :param grads: list of (grad, var).
-    :param gradprocs: list of `GradientProcessor` instances.
-    :returns: list of (grad, var) went through the processors
+    Args:
+        grads (list): list of (grad, var).
+        gradprocs (list): list of :class:`GradientProcessor` instances.
+
+    Returns:
+        list: list of (grad, var) that went through the processors.
     """
     g = []
     for grad, var in grads:
@@ -36,13 +37,18 @@ def apply_grad_processors(grads, gradprocs):
 @six.add_metaclass(ABCMeta)
 class GradientProcessor(object):
+    """ Base class for all gradient processors.
+
+    Subclass should override the ``_process()`` method.
+    """

     def process(self, grads):
         """
         Process the symbolic gradients.

-        :param grads: list of (grad, var)
-        :returns: symbolic gradients with the same type as input
+        Args:
+            grads (list): list of (grad, var).
+        Returns:
+            list: processed gradients, with the same type as input.
         """
         with tf.name_scope(type(self).__name__):
             return self._process(grads)
@@ -53,10 +59,16 @@ class GradientProcessor(object):
 class GlobalNormClip(GradientProcessor):
+    """ Clip by global norm.
+    The global norm is the sum of norm for **all** gradients.
+
+    See :func:`tf.clip_by_global_norm` for more information.
+    """

     def __init__(self, global_norm):
-        """ Clip by global norm
-        Note that the global norm is the sum of norm for **all** gradients
+        """
+        Args:
+            global_norm(float): the threshold to clip with.
         """
         self._norm = global_norm
@@ -75,9 +87,10 @@ class MapGradient(GradientProcessor):
     def __init__(self, func, regex='.*'):
         """
-        :param func: takes a grad or (grad, var) pair and returns a grad. If return None, the
-            gradient is discarded.
-        :param regex: used to match variables. default to match all variables.
+        Args:
+            func: takes a grad or (grad, var) pair and returns a grad. If it returns None, the
+                gradient is discarded (hence no update to the variable will happen).
+            regex (str): used to match variables. Defaults to match all variables.
         """
         args = inspect.getargspec(func).args
         arg_num = len(args) - inspect.ismethod(func)
@@ -109,7 +122,7 @@ _summaried_gradient = set()
 class SummaryGradient(MapGradient):
     """
-    Summary history and RMS for each graident variable
+    Summarize the histogram and RMS of each gradient variable.
     """

     def __init__(self):
@@ -127,6 +140,7 @@ class SummaryGradient(MapGradient):
 class CheckGradient(MapGradient):
     """
     Check for numeric issue.
+    See :func:`tf.check_numerics` for more information.
     """

     def __init__(self):
@@ -141,13 +155,21 @@ class CheckGradient(MapGradient):
 class ScaleGradient(MapGradient):
     """
-    Scale certain gradient by a multiplier
+    Scale certain gradients by a multiplier.
     """

     def __init__(self, multipliers, log=True):
         """
-        :param multipliers: list of (regex, float)
-        :param log: whether to do logging or not
+        Args:
+            multipliers (tuple or list): tuple of (regex, float), or list of such tuples.
+            log (bool): whether to do logging or not
+
+        Example:
+            Use double learning rate for all the bias variables (as in Caffe):
+
+            .. code-block:: python
+
+                ScaleGradient(('.*/b', 2))
         """
         if not isinstance(multipliers, list):
             multipliers = [multipliers]
...
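To illustrate how these processors compose, here is a small, hypothetical ModelDesc method combining the classes documented above, in the spirit of the earlier example diff (import paths, multipliers and regexes are assumptions, not part of this commit):

    from tensorpack import ModelDesc
    from tensorpack.tfutils.gradproc import (
        GlobalNormClip, ScaleGradient, SummaryGradient, CheckGradient)

    class MyModel(ModelDesc):   # hypothetical model; only the relevant method is shown
        def get_gradient_processor(self):
            # processors are applied in order by apply_grad_processors()
            return [GlobalNormClip(5),                      # clip all gradients by global norm 5
                    ScaleGradient(('.*/b', 2), log=False),  # double the gradient of all biases
                    SummaryGradient(),                      # histogram + RMS summary per gradient
                    CheckGradient()]                        # tf.check_numerics on every gradient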
@@ -11,7 +11,8 @@ import six
 from ..utils import logger, PREDICT_TOWER
 from .common import get_op_var_name
-from .varmanip import SessionUpdate, get_savename_from_varname, is_training_name
+from .varmanip import (SessionUpdate, get_savename_from_varname,
+                       is_training_name, get_checkpoint_path)

 __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
            'ParamRestore', 'ChainInit',
@@ -22,12 +23,14 @@ __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
 @six.add_metaclass(ABCMeta)
 class SessionInit(object):
-    """ Base class for utilities to initialize a session"""
+    """ Base class for utilities to initialize a session. """

     def init(self, sess):
-        """ Initialize a session
-        :param sess: a `tf.Session`
+        """
+        Initialize a session
+
+        Args:
+            sess (tf.Session): the session
         """
         self._init(sess)
@@ -37,7 +40,7 @@ class SessionInit(object):
 class JustCurrentSession(SessionInit):
-    """ Just use the current default session. This is a no-op placeholder"""
+    """ This is a no-op placeholder"""

     def _init(self, sess):
         pass
@@ -45,8 +48,7 @@ class JustCurrentSession(SessionInit):
 class NewSession(SessionInit):
     """
-    Create a new session. All variables will be initialized by their
-    initializer.
+    Initialize global variables by their initializer.
     """

     def _init(self, sess):
@@ -55,32 +57,17 @@ class NewSession(SessionInit):
 class SaverRestore(SessionInit):
     """
-    Restore an old model saved by `ModelSaver`.
+    Restore an old model saved by :class:`ModelSaver`.
     """

     def __init__(self, model_path, prefix=None):
         """
-        :param model_path: a model name (model-xxxx) or a ``checkpoint`` file.
-        :param prefix: add a `prefix/` for every variable in this checkpoint
+        Args:
+            model_path (str): a model name (model-xxxx) or a ``checkpoint`` file.
+            prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint
         """
-        if os.path.basename(model_path) == model_path:
-            model_path = os.path.join('.', model_path)  # avoid #4921 and #6142
-        if os.path.basename(model_path) == 'checkpoint':
-            model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
-            # to be consistent with either v1 or v2
-        # fix paths if provided a wrong one
-        new_path = model_path
-        if '00000-of-00001' in model_path:
-            new_path = model_path.split('.data')[0]
-        elif model_path.endswith('.index'):
-            new_path = model_path.split('.index')[0]
-        if new_path != model_path:
-            logger.warn(
-                "[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
-            model_path = new_path
-        assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
-        self.set_path(model_path)
+        model_path = get_checkpoint_path(model_path)
+        self.path = model_path
         self.prefix = prefix

     def _init(self, sess):
@@ -94,9 +81,6 @@ class SaverRestore(SessionInit):
             saver = tf.train.Saver(var_list=dic, name=str(id(dic)), write_version=2)
             saver.restore(sess, self.path)

-    def set_path(self, model_path):
-        self.path = model_path
-
     @staticmethod
     def _produce_restore_dict(vars_multimap):
         """
@@ -161,7 +145,8 @@ class ParamRestore(SessionInit):
     def __init__(self, param_dict):
         """
-        :param param_dict: a dict of {name: value}
+        Args:
+            param_dict (dict): a dict of {name: value}
         """
         # use varname (with :0) for consistency
         self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}
@@ -190,12 +175,17 @@ class ParamRestore(SessionInit):
 class ChainInit(SessionInit):
-    """ Init a session by a list of SessionInit instance."""
+    """ Initialize a session by a list of :class:`SessionInit` instances, executed one by one.
+
+    This can be useful for, e.g., loading several models from different files
+    to form a composition of models.
+    """

     def __init__(self, sess_inits, new_session=True):
         """
-        :params sess_inits: list of `SessionInit` instances.
-        :params new_session: add a `NewSession()` and the beginning, if not there
+        Args:
+            sess_inits (list): list of :class:`SessionInit` instances.
+            new_session (bool): add a ``NewSession()`` at the beginning, if
+                not there.
         """
         if new_session and not isinstance(sess_inits[0], NewSession):
             sess_inits.insert(0, NewSession())
@@ -208,8 +198,11 @@ class ChainInit(SessionInit):
 def get_model_loader(filename):
     """
-    Get a corresponding model loader by looking at the file name
-    :return: either a ParamRestore or SaverRestore
+    Get a corresponding model loader by looking at the file name.
+
+    Returns:
+        SessionInit: either a :class:`ParamRestore` (if name ends with 'npy') or
+        :class:`SaverRestore` (otherwise).
     """
     if filename.endswith('.npy'):
         assert os.path.isfile(filename), filename
...
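A short, hypothetical sketch of how these initializers are typically combined (import path and file names are made up, so this only runs against a real training log):

    import numpy as np
    from tensorpack.tfutils.sessinit import SaverRestore, ParamRestore, ChainInit, get_model_loader

    # pick a loader based on the extension: .npy -> ParamRestore, otherwise SaverRestore
    init = get_model_loader('train_log/run1/model-10000')
    # or chain several initializers, executed one by one; a NewSession() is prepended automatically
    init = ChainInit([SaverRestore('train_log/run1/model-10000'),
                      ParamRestore(np.load('pretrained.npy').item())])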
@@ -102,8 +102,10 @@ def add_param_summary(*summary_lists):
 def add_moving_summary(v, *args):
     """
-    :param v: tensor or list of tensor to summary
-    :param args: tensors to summary
+    Args:
+        v (tf.Tensor or list): tensor or list of tensors to summary. Must have
+            scalar type.
+        args: more tensors to summary (to support positional arguments)
     """
     ctx = get_current_tower_context()
     if ctx is not None and not ctx.is_main_training_tower:
@@ -119,9 +121,14 @@ def add_moving_summary(v, *args):
 @memoized
 def summary_moving_average(tensors=None):
     """
-    Create a MovingAverage op and add summary for tensors
-    :param tensors: list of tf.Tensor to summary. default to the collection MOVING_SUMMARY_VARS_KEY
-    :returns: a op to maintain these average.
+    Create a MovingAverage Op and add summary Ops for all the moving averages.
+    This is called by the trainer.
+
+    Args:
+        tensors(list): list of tf.Tensor to summary. Defaults to the
+            collection ``MOVING_SUMMARY_VARS_KEY``.
+    Returns:
+        tf.Operation: an op to maintain these averages.
     """
     if tensors is None:
         tensors = set(tf.get_collection(MOVING_SUMMARY_VARS_KEY))
...
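For illustration, the typical way these summary helpers are used inside a model (a minimal sketch; the cost tensor and import path are assumptions):

    import tensorflow as tf
    from tensorpack.tfutils.summary import add_moving_summary, add_param_summary

    # inside a ModelDesc._build_graph(); `cost` stands in for whatever scalar loss the model defines
    cost = tf.reduce_mean(tf.square(tf.random_normal([8])), name='cost')
    add_moving_summary(cost)                             # scalar tensors only
    add_param_summary(('.*/W', ['histogram', 'rms']))    # monitor all W, as in the DQN example above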
@@ -8,9 +8,13 @@ import numpy as np
 def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
     """
-    :param logits: NxC
-    :param label: N
-    :returns: a float32 vector of length N with 0/1 values. 1 means incorrect prediction
+    Args:
+        logits: (N,C)
+        label: (N,)
+        topk(int): topk
+    Returns:
+        a float32 vector of length N with 0/1 values. 1 means incorrect
+        prediction.
     """
     return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)),
                    tf.float32, name=name)
@@ -39,9 +43,11 @@ def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
     as in `Holistically-Nested Edge Detection
     <http://arxiv.org/abs/1504.06375>`_.

-    :param pred: size: b x ANYTHING. the predictions in [0,1].
-    :param label: size: b x ANYTHING. the ground truth in {0,1}.
-    :returns: class-balanced cross entropy loss
+    Args:
+        pred: of shape (b, ...). the predictions in [0,1].
+        label: of the same shape. the ground truth in {0,1}.
+    Returns:
+        class-balanced cross entropy loss.
     """
     z = batch_flatten(pred)
     y = tf.cast(batch_flatten(label), tf.float32)
@@ -59,14 +65,8 @@ def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
 def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
     """
-    The class-balanced cross entropy loss,
-    as in `Holistically-Nested Edge Detection
-    <http://arxiv.org/abs/1504.06375>`_.
-    This is more numerically stable than class_balanced_cross_entropy
-
-    :param logits: size: the logits.
-    :param label: size: the ground truth in {0,1}, of the same shape as logits.
-    :returns: a scalar. class-balanced cross entropy loss
+    This function accepts logits rather than predictions, and is more numerically stable than
+    :func:`class_balanced_cross_entropy`.
     """
     y = tf.cast(label, tf.float32)
@@ -77,17 +77,12 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
     pos_weight = beta / (1 - beta)
     cost = tf.nn.weighted_cross_entropy_with_logits(logits, y, pos_weight)
     cost = tf.reduce_mean(cost * (1 - beta), name=name)
-    # logstable = tf.log(1 + tf.exp(-tf.abs(z)))
-    # loss_pos = -beta * tf.reduce_mean(-y * (logstable - tf.minimum(0.0, z)))
-    # loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) * (logstable + tf.maximum(z, 0.0)))
-    # cost = tf.sub(loss_pos, loss_neg, name=name)
     return cost

 def print_stat(x, message=None):
-    """ a simple print op.
-    Use it like: x = print_stat(x)
+    """ A simple print Op that might be easier to use than :meth:`tf.Print`.
+    Use it like: ``x = print_stat(x, message='This is x')``.
     """
     if message is None:
         message = x.op.name
@@ -96,6 +91,10 @@ def print_stat(x, message=None):
 def rms(x, name=None):
+    """
+    Returns:
+        root mean square of tensor x.
+    """
     if name is None:
         name = x.op.name + '/rms'
     with tf.name_scope(None):   # name already contains the scope
@@ -104,19 +103,41 @@ def rms(x, name=None):
 def huber_loss(x, delta=1, name='huber_loss'):
+    r"""
+    Huber loss of x.
+
+    .. math::
+
+        y = \begin{cases} \frac{x^2}{2}, & |x| < \delta \\
+        \delta |x| - \frac{\delta^2}{2}, & |x| \ge \delta
+        \end{cases}
+
+    Args:
+        x: the difference vector.
+        delta (float): the threshold at which the loss becomes linear.
+
+    Returns:
+        a tensor of the same shape as x.
+    """
     sqrcost = tf.square(x)
     abscost = tf.abs(x)
-    return tf.reduce_sum(
-        tf.select(abscost < delta,
-                  sqrcost * 0.5,
-                  abscost * delta - 0.5 * delta ** 2),
-        name=name)
+    return tf.select(abscost < delta,
+                     sqrcost * 0.5,
+                     abscost * delta - 0.5 * delta ** 2,
+                     name=name)

 def get_scalar_var(name, init_value, summary=False, trainable=False):
     """
-    get a scalar variable with certain initial value
-    :param summary: summary this variable
+    Get a scalar variable with a certain initial value.
+
+    Args:
+        name (str): name of the variable.
+        init_value (float): initial value.
+        summary (bool): whether to summary this variable.
+        trainable (bool): trainable or not.
+    Returns:
+        tf.Variable: the variable
     """
     ret = tf.get_variable(name, shape=[],
                           initializer=tf.constant_initializer(init_value),
...
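A common use of get_scalar_var, sketched here, is creating a learning-rate variable that callbacks can later assign to (the name, value and import path are examples, not part of this commit):

    from tensorpack.tfutils.symbolic_functions import get_scalar_var

    # a non-trainable scalar variable that is also summarized, e.g. the learning rate
    lr = get_scalar_var('learning_rate', 1e-3, summary=True)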
@@ -13,9 +13,14 @@ _CurrentTowerContext = None
 class TowerContext(object):
+    """ A context in which the current model is being built. """
+
     def __init__(self, tower_name, is_training=None):
-        """ tower_name: 'tower0', 'towerp0', or '' """
+        """
+        Args:
+            tower_name (str): 'tower0', 'towerp0', or ''
+            is_training (bool): if None, automatically determine from tower_name.
+        """
         self._name = tower_name
         if is_training is None:
             is_training = not self._name.startswith(PREDICT_TOWER)
@@ -39,12 +44,15 @@ class TowerContext(object):
     def get_variable_on_tower(self, *args, **kwargs):
         """
-        Get a variable for this tower specifically, without reusing.
-        Tensorflow doesn't allow reuse=False scope under a
-        reuse=True scope. This method provides a work around.
+        Get a variable for this tower specifically, without reusing, even if
+        it is called under a ``reuse=True`` variable scope.
+
+        Tensorflow doesn't allow us to disable reuse under a
+        ``reuse=True`` scope. This method provides a workaround.
+
         See https://www.tensorflow.org/versions/master/how_tos/variable_scope/index.html#basics-of-tfvariable-scope
-        :param args, kwargs: same as tf.get_variable()
+
+        Args:
+            args, kwargs: same as ``tf.get_variable()``.
         """
         with tf.variable_scope(self._name) as scope:
             with tf.variable_scope(scope, reuse=False):
...
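An illustrative sketch of how a tower context is entered while building the graph (the tower name is arbitrary and the import path is assumed; trainers normally enter the context for you):

    import tensorflow as tf
    from tensorpack.tfutils.tower import TowerContext

    with TowerContext('tower0'):      # is_training inferred from the name (True here)
        x = tf.placeholder(tf.float32, [None, 32])
        # layers built here can see this context via get_current_tower_context()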
@@ -14,17 +14,20 @@ from ..utils.naming import PREDICT_TOWER
 from .common import get_op_tensor_name

 __all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
-           'get_savename_from_varname', 'is_training_name']
+           'get_savename_from_varname', 'is_training_name',
+           'get_checkpoint_path']

 def get_savename_from_varname(
         varname, varname_prefix=None,
         savename_prefix=None):
     """
-    :param varname: a variable name in the graph
-    :param varname_prefix: an optional prefix that may need to be removed in varname
-    :param savename_prefix: an optional prefix to append to all savename
-    :returns: the name used to save the variable
+    Args:
+        varname(str): a variable name in the graph
+        varname_prefix(str): an optional prefix that may need to be removed in varname
+        savename_prefix(str): an optional prefix to add to all savenames
+    Returns:
+        str: the name used to save the variable
     """
     name = varname
     if PREDICT_TOWER in name:
@@ -46,7 +49,9 @@ class SessionUpdate(object):
     def __init__(self, sess, vars_to_update):
         """
-        :param vars_to_update: a collection of variables to update
+        Args:
+            sess (tf.Session): a session object
+            vars_to_update: a collection of variables to update
         """
         self.sess = sess
         self.assign_ops = defaultdict(list)
@@ -60,8 +65,9 @@ class SessionUpdate(object):
     def update(self, prms):
         """
-        :param prms: dict of {variable name: value}
-            Any name in prms must be in the graph and in vars_to_update.
+        Args:
+            prms(dict): dict of {variable name: value}
+                Any name in prms must be in the graph and in vars_to_update.
         """
         for name, value in six.iteritems(prms):
             assert name in self.assign_ops
@@ -77,8 +83,12 @@ class SessionUpdate(object):
 def dump_session_params(path):
-    """ Dump value of all trainable + to_save variables to a dict and save to `path` as
-    npy format, loadable by ParamRestore
+    """
+    Dump the value of all TRAINABLE + MODEL variables to a dict, and save it in
+    npy format (loadable by :class:`ParamRestore`).
+
+    Args:
+        path(str): the path to save the parameters.
     """
     var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
     var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
@@ -96,10 +106,42 @@ the same name".format(v.name))
     np.save(path, result)

-def dump_chkpt_vars(model_path):
-    """ Dump all variables from a checkpoint to a dict"""
-    if os.path.basename(model_path) == model_path:
-        model_path = os.path.join('.', model_path)  # avoid #4921
+def get_checkpoint_path(model_path):
+    """
+    Work around TF problems in checkpoint path handling.
+
+    Args:
+        model_path: a user-input path
+    Returns:
+        str: the argument that can be passed to NewCheckpointReader
+    """
+    if os.path.basename(model_path) == model_path:
+        model_path = os.path.join('.', model_path)  # avoid #4921 and #6142
+    if os.path.basename(model_path) == 'checkpoint':
+        model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
+        # to be consistent with either v1 or v2
+
+    # fix paths if provided a wrong one
+    new_path = model_path
+    if '00000-of-00001' in model_path:
+        new_path = model_path.split('.data')[0]
+    elif model_path.endswith('.index'):
+        new_path = model_path.split('.index')[0]
+    if new_path != model_path:
+        logger.warn(
+            "[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
+        model_path = new_path
+    assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
+    return model_path
+
+def dump_chkpt_vars(model_path):
+    """ Dump all variables from a checkpoint to a dict.
+
+    Args:
+        model_path(str): path to a checkpoint.
+    """
+    model_path = get_checkpoint_path(model_path)
     reader = tf.train.NewCheckpointReader(model_path)
     var_names = reader.get_variable_to_shape_map().keys()
     result = {}
@@ -110,8 +152,10 @@ def dump_chkpt_vars(model_path):
 def is_training_name(name):
     """
-    This is only used to improve logging.
-    :returns: guess whether this tensor is something only used in training.
+    This is a hack temporarily used to improve logging. Do not use this function.
+
+    Returns:
+        bool: Guess whether this tensor is something only used in training.
     """
     # TODO: maybe simply check against TRAINABLE_VARIABLES and MODEL_VARIABLES?
     # TODO or use get_slot_names()
...
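To make the new get_checkpoint_path behavior concrete, a hypothetical sketch (the paths are made up; the function asserts that the files exist, so this is purely illustrative):

    from tensorpack.tfutils.varmanip import get_checkpoint_path

    # hypothetical inputs -> what the function normalizes them to:
    #   'train_log/run1/model-10000.data-00000-of-00001' -> 'train_log/run1/model-10000'
    #   'train_log/run1/model-10000.index'               -> 'train_log/run1/model-10000'
    #   'train_log/run1/checkpoint'                      -> latest checkpoint in that directory
    path = get_checkpoint_path('train_log/run1/model-10000.index')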
@@ -14,6 +14,13 @@ _ORIG_GET_VARIABLE = tf.get_variable
 @contextmanager
 def replace_get_variable(fn):
+    """
+    Args:
+        fn: a function taking the same arguments as ``tf.get_variable``.
+
+    Returns:
+        a context where ``tf.get_variable`` and
+        ``variable_scope.get_variable`` are replaced with ``fn``.
+    """
     old_getv = tf.get_variable
     old_vars_getv = variable_scope.get_variable
@@ -26,8 +33,10 @@ def replace_get_variable(fn):
 def freeze_get_variable():
     """
-    Return a contextmanager, where all variables returned by
-    `get_variable` will have no gradients.
+    Return a context, where all variables (reused or not) returned by
+    ``get_variable`` will have no gradients (surrounded by ``tf.stop_gradient``).
+    But they will still be in the ``TRAINABLE_VARIABLES`` collection, so they will get
+    saved correctly. This is useful for fixing certain variables during fine-tuning.

     Example:
         .. code-block:: python
...
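The docstring's own example is cut off in this view; independent of it, a hypothetical fine-tuning sketch of the same idea (import path and variable name are assumptions):

    import tensorflow as tf
    from tensorpack.tfutils.varreplace import freeze_get_variable

    with freeze_get_variable():
        # variables created (or reused) in here are wrapped in tf.stop_gradient,
        # so the optimizer will not update them, e.g. a frozen backbone during fine-tuning
        w = tf.get_variable('frozen_w', shape=[3, 3, 16, 16])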