Commit 06ea1c0a authored by Yuxin Wu

api docs for tfutils/

parent bbf41d9e
......@@ -67,8 +67,7 @@ extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.napoleon',
#'sphinx.ext.coverage',
#'sphinx.ext.mathjax',
'sphinx.ext.mathbase',
'sphinx.ext.mathjax',
'sphinx.ext.intersphinx',
'sphinx.ext.viewcode',
]
......
......@@ -132,8 +132,8 @@ class Model(ModelDesc):
target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)
self.cost = tf.truediv(symbf.huber_loss(target - pred_action_value),
tf.cast(BATCH_SIZE, tf.float32), name='cost')
self.cost = tf.reduce_mean(symbf.huber_loss(
target - pred_action_value), name='cost')
summary.add_param_summary(('conv.*/W', ['histogram', 'rms']),
('fc.*/W', ['histogram', 'rms'])) # monitor all W
add_moving_summary(self.cost)
......
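Side note: the new ``cost`` only works because ``huber_loss`` is changed further down in this commit to return an elementwise tensor instead of a reduced sum, so the manual division by ``BATCH_SIZE`` is no longer needed. A minimal sketch of the equivalence, with a hypothetical per-sample loss tensor:

    import tensorflow as tf

    BATCH_SIZE = 64                                   # assumed constant from the example
    per_sample = tf.placeholder(tf.float32, [None])   # e.g. huber_loss(target - pred_action_value)

    # old: huber_loss returned a sum over the batch, so it was normalized by hand
    cost_old = tf.truediv(tf.reduce_sum(per_sample),
                          tf.cast(BATCH_SIZE, tf.float32), name='cost_old')
    # new: huber_loss is elementwise, so taking the mean is enough
    cost_new = tf.reduce_mean(per_sample, name='cost_new')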
......@@ -88,7 +88,7 @@ class Model(ModelDesc):
def get_gradient_processor(self):
return [MapGradient(lambda grad: tf.clip_by_global_norm([grad], 5)[0][0]),
ScaleGradient([('STN.*', 0.1)]), SummaryGradient()]
ScaleGradient(('STN.*', 0.1)), SummaryGradient()]
def get_data(isTrain):
......
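This call site changes because ``ScaleGradient`` now also accepts a single ``(regex, multiplier)`` tuple instead of only a list of them. A small sketch under that assumption (the import path is assumed, not shown in this diff):

    from tensorpack.tfutils.gradproc import ScaleGradient   # module path assumed

    proc = ScaleGradient(('STN.*', 0.1))                     # single (regex, multiplier) tuple
    procs = ScaleGradient([('STN.*', 0.1), ('.*/b', 2)])     # or a list of such tuples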
......@@ -34,5 +34,4 @@ for _, module_name, _ in walk_packages(
continue
if module_name in _TO_IMPORT:
_global_import(module_name)
if module_name != 'common':
__all__.append(module_name)
__all__.extend(['sessinit', 'gradproc'])
......@@ -14,13 +14,30 @@ _ArgScopeStack = []
@contextmanager
def argscope(layers, **param):
def argscope(layers, **kwargs):
"""
Args:
layers (list or layer): layer or list of layers to apply the arguments to.
Returns:
a context where all appearances of these layers will by default have the
arguments specified by kwargs.
Example:
.. code-block:: python
with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32):
x = Conv2D('conv0', x)
x = Conv2D('conv1', x)
x = Conv2D('conv2', x, out_channel=64) # override argscope
"""
if not isinstance(layers, list):
layers = [layers]
def _check_args_exist(l):
args = inspect.getargspec(l).args
for k, v in six.iteritems(param):
for k, v in six.iteritems(kwargs):
assert k in args, "No argument {} in {}".format(k, l.__name__)
for l in layers:
......@@ -29,7 +46,7 @@ def argscope(layers, **param):
new_scope = copy.copy(get_arg_scope())
for l in layers:
new_scope[l.__name__].update(param)
new_scope[l.__name__].update(kwargs)
_ArgScopeStack.append(new_scope)
yield
del _ArgScopeStack[-1]
......@@ -37,8 +54,10 @@ def argscope(layers, **param):
def get_arg_scope():
"""
:returns: the current argscope.
An argscope is a dict of dict: dict[layername] = {arg: val}
Returns:
dict: the current argscope.
An argscope is a dict of dict: ``dict[layername] = {arg: val}``
"""
if len(_ArgScopeStack) > 0:
return _ArgScopeStack[-1]
......
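A minimal usage sketch of ``argscope``/``get_arg_scope`` as documented above; ``Conv2D`` and the import paths are assumptions for illustration, not part of this diff:

    import tensorflow as tf
    from tensorpack.models import Conv2D                              # layer import assumed
    from tensorpack.tfutils.argscope import argscope, get_arg_scope   # module path assumed

    with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32):
        # defaults are stored per layer name: dict[layername] = {arg: val}
        print(get_arg_scope()['Conv2D'])
    # outside the context the defaults are gone again
    print(get_arg_scope()['Conv2D'])   # {}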
......@@ -28,8 +28,10 @@ def get_default_sess_config(mem_fraction=0.99):
Return a better session config to use as default.
The TensorFlow default session config consumes too many resources.
:param mem_fraction: fraction of memory to use. default to 0.99
:returns: a `tf.ConfigProto` object.
Args:
mem_fraction(float): fraction of memory to use.
Returns:
tf.ConfigProto: the config to use.
"""
conf = tf.ConfigProto()
conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
......@@ -41,7 +43,11 @@ def get_default_sess_config(mem_fraction=0.99):
def get_global_step_var():
""" :returns: the global_step variable in the current graph. create if not existed"""
"""
Returns:
tf.Tensor: the global_step variable in the current graph; created if
it doesn't exist.
"""
try:
return tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
except KeyError:
......@@ -56,7 +62,9 @@ def get_global_step_var():
def get_global_step():
""" :returns: global_step value in current graph and session"""
"""
Returns:
float: global_step value in current graph and session"""
return tf.train.global_step(
tf.get_default_session(),
get_global_step_var())
......@@ -66,8 +74,10 @@ def get_op_tensor_name(name):
"""
Tensor name is assumed to be ``op_name + ':0'``
:param name: an op or a tensor name
:returns: (op_name, tensor_name)
Args:
name(str): name of an op or a tensor
Returns:
tuple: (op_name, tensor_name)
"""
if name.endswith(':0'):
return name[:-2], name
......@@ -80,7 +90,10 @@ get_op_var_name = get_op_tensor_name
def get_tensors_by_names(names):
"""
Get a list of tensors in the default graph by a list of names
Get a list of tensors in the default graph by a list of names.
Args:
names (list): list of tensor names.
"""
ret = []
G = tf.get_default_graph()
......@@ -94,6 +107,12 @@ get_vars_by_names = get_tensors_by_names
def backup_collection(keys):
"""
Args:
keys (list): list of collection keys to backup
Returns:
dict: the backup
"""
ret = {}
for k in keys:
ret[k] = copy(tf.get_collection(k))
......@@ -101,22 +120,45 @@ def backup_collection(keys):
def restore_collection(backup):
"""
Restore from a collection backup.
Args:
backup (dict):
"""
for k, v in six.iteritems(backup):
del tf.get_collection_ref(k)[:]
tf.get_collection_ref(k).extend(v)
def clear_collection(keys):
"""
Clear some collections.
Args:
keys(list): list of collection keys.
"""
for k in keys:
del tf.get_collection_ref(k)[:]
@contextmanager
def freeze_collection(keys):
"""
Args:
keys(list): list of collection keys to freeze.
Returns:
a context where the collections are restored to their initial states at the end.
"""
backup = backup_collection(keys)
yield
restore_collection(backup)
def get_tf_version():
"""
Returns:
int: the minor version number of TensorFlow.
"""
return int(tf.__version__.split('.')[1])
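A short sketch of the collection helpers documented above (the collection key is hypothetical and the import path is assumed):

    import tensorflow as tf
    from tensorpack.tfutils.common import freeze_collection   # module path assumed

    with freeze_collection(['my_collection']):
        tf.add_to_collection('my_collection', tf.constant(1.0))
        assert len(tf.get_collection('my_collection')) == 1
    # on exit the collection is restored to its previous (here: empty) state
    assert tf.get_collection('my_collection') == []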
......@@ -12,16 +12,17 @@ from ..utils import logger
from .symbolic_functions import rms
from .summary import add_moving_summary
__all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
'ScaleGradient', 'MapGradient', 'apply_grad_processors',
'GlobalNormClip']
__all__ = ['GradientProcessor', 'GlobalNormClip', 'MapGradient', 'SummaryGradient', 'CheckGradient',
'ScaleGradient', 'apply_grad_processors']
def apply_grad_processors(grads, gradprocs):
"""
:param grads: list of (grad, var).
:param gradprocs: list of `GradientProcessor` instances.
:returns: list of (grad, var) went through the processors
Args:
grads (list): list of (grad, var).
gradprocs (list): list of :class:`GradientProcessor` instances.
Returns:
list: list of (grad, var) that went through the processors.
"""
g = []
for grad, var in grads:
......@@ -36,13 +37,18 @@ def apply_grad_processors(grads, gradprocs):
@six.add_metaclass(ABCMeta)
class GradientProcessor(object):
""" Base class for all gradient processors.
Subclass should override the ``_process()`` method.
"""
def process(self, grads):
"""
Process the symbolic gradients.
:param grads: list of (grad, var)
:returns: symbolic gradients with the same type as input
Args:
grads (list): list of (grad, var).
Returns:
list: processed gradients, with the same type as input.
"""
with tf.name_scope(type(self).__name__):
return self._process(grads)
......@@ -53,10 +59,16 @@ class GradientProcessor(object):
class GlobalNormClip(GradientProcessor):
""" Clip by global norm.
The global norm is the sum of norm for **all** gradients.
See :func:`tf.clip_by_global_norm` for more information.
"""
def __init__(self, global_norm):
""" Clip by global norm
Note that the global norm is the sum of norm for **all** gradients
"""
Args:
global_norm(float): the threshold to clip with.
"""
self._norm = global_norm
......@@ -75,9 +87,10 @@ class MapGradient(GradientProcessor):
def __init__(self, func, regex='.*'):
"""
:param func: takes a grad or (grad, var) pair and returns a grad. If return None, the
gradient is discarded.
:param regex: used to match variables. default to match all variables.
Args:
func: takes a grad or (grad, var) pair and returns a grad. If it returns None, the
gradient is discarded (hence no update to the variable will happen).
regex (str): used to match variables. Defaults to match all variables.
"""
args = inspect.getargspec(func).args
arg_num = len(args) - inspect.ismethod(func)
......@@ -109,7 +122,7 @@ _summaried_gradient = set()
class SummaryGradient(MapGradient):
"""
Summary history and RMS for each graident variable
Summarize the histogram and RMS for each gradient variable.
"""
def __init__(self):
......@@ -127,6 +140,7 @@ class SummaryGradient(MapGradient):
class CheckGradient(MapGradient):
"""
Check for numeric issues.
See :func:`tf.check_numerics` for more information.
"""
def __init__(self):
......@@ -141,13 +155,21 @@ class CheckGradient(MapGradient):
class ScaleGradient(MapGradient):
"""
Scale certain gradient by a multiplier
Scale certain gradients by a multiplier.
"""
def __init__(self, multipliers, log=True):
"""
:param multipliers: list of (regex, float)
:param log: whether to do logging or not
Args:
multipliers (tuple or list): tuple of (regex, float), or list of tuples.
log (bool): whether to do logging or not
Example:
Use double the learning rate for all biases (as in Caffe):
.. code-block:: python
ScaleGradient(('.*/b', 2))
"""
if not isinstance(multipliers, list):
multipliers = [multipliers]
......
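Putting the documented pieces together, a hedged sketch of composing gradient processors with ``apply_grad_processors`` (import path assumed; the variable and loss are toy placeholders):

    import tensorflow as tf
    from tensorpack.tfutils.gradproc import (
        apply_grad_processors, GlobalNormClip, ScaleGradient)   # module path assumed

    with tf.variable_scope('fc0'):
        b = tf.get_variable('b', shape=[], initializer=tf.constant_initializer(0.0))
    loss = tf.square(b - 1.0)

    opt = tf.train.GradientDescentOptimizer(0.1)
    grads = opt.compute_gradients(loss)                  # list of (grad, var)
    grads = apply_grad_processors(grads, [
        GlobalNormClip(5),                               # clip by global norm
        ScaleGradient(('.*/b', 2)),                      # double the gradient of every bias
    ])
    train_op = opt.apply_gradients(grads)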
......@@ -11,7 +11,8 @@ import six
from ..utils import logger, PREDICT_TOWER
from .common import get_op_var_name
from .varmanip import SessionUpdate, get_savename_from_varname, is_training_name
from .varmanip import (SessionUpdate, get_savename_from_varname,
is_training_name, get_checkpoint_path)
__all__ = ['SessionInit', 'NewSession', 'SaverRestore',
'ParamRestore', 'ChainInit',
......@@ -22,12 +23,14 @@ __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
@six.add_metaclass(ABCMeta)
class SessionInit(object):
""" Base class for utilities to initialize a session"""
""" Base class for utilities to initialize a session. """
def init(self, sess):
""" Initialize a session
"""
Initialize a session
:param sess: a `tf.Session`
Args:
sess (tf.Session): the session
"""
self._init(sess)
......@@ -37,7 +40,7 @@ class SessionInit(object):
class JustCurrentSession(SessionInit):
""" Just use the current default session. This is a no-op placeholder"""
""" This is a no-op placeholder"""
def _init(self, sess):
pass
......@@ -45,8 +48,7 @@ class JustCurrentSession(SessionInit):
class NewSession(SessionInit):
"""
Create a new session. All variables will be initialized by their
initializer.
Initialize global variables by their initializer.
"""
def _init(self, sess):
......@@ -55,32 +57,17 @@ class NewSession(SessionInit):
class SaverRestore(SessionInit):
"""
Restore an old model saved by `ModelSaver`.
Restore an old model saved by :class:`ModelSaver`.
"""
def __init__(self, model_path, prefix=None):
"""
:param model_path: a model name (model-xxxx) or a ``checkpoint`` file.
:param prefix: add a `prefix/` for every variable in this checkpoint
Args:
model_path (str): a model name (model-xxxx) or a ``checkpoint`` file.
prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint
"""
if os.path.basename(model_path) == model_path:
model_path = os.path.join('.', model_path) # avoid #4921 and #6142
if os.path.basename(model_path) == 'checkpoint':
model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
# to be consistent with either v1 or v2
# fix paths if provided a wrong one
new_path = model_path
if '00000-of-00001' in model_path:
new_path = model_path.split('.data')[0]
elif model_path.endswith('.index'):
new_path = model_path.split('.index')[0]
if new_path != model_path:
logger.warn(
"[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
model_path = new_path
assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
self.set_path(model_path)
model_path = get_checkpoint_path(model_path)
self.path = model_path
self.prefix = prefix
def _init(self, sess):
......@@ -94,9 +81,6 @@ class SaverRestore(SessionInit):
saver = tf.train.Saver(var_list=dic, name=str(id(dic)), write_version=2)
saver.restore(sess, self.path)
def set_path(self, model_path):
self.path = model_path
@staticmethod
def _produce_restore_dict(vars_multimap):
"""
......@@ -161,7 +145,8 @@ class ParamRestore(SessionInit):
def __init__(self, param_dict):
"""
:param param_dict: a dict of {name: value}
Args:
param_dict (dict): a dict of {name: value}
"""
# use varname (with :0) for consistency
self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}
......@@ -190,12 +175,17 @@ class ParamRestore(SessionInit):
class ChainInit(SessionInit):
""" Init a session by a list of SessionInit instance."""
""" Initialize a session by a list of :class:`SessionInit` instance, executed one by one.
This can be useful for, e.g., loading several models from different files
to form a composition of models.
"""
def __init__(self, sess_inits, new_session=True):
"""
:params sess_inits: list of `SessionInit` instances.
:params new_session: add a `NewSession()` and the beginning, if not there
Args:
sess_inits (list): list of :class:`SessionInit` instances.
new_session (bool): add a ``NewSession()`` at the beginning, if
not there.
"""
if new_session and not isinstance(sess_inits[0], NewSession):
sess_inits.insert(0, NewSession())
......@@ -208,8 +198,11 @@ class ChainInit(SessionInit):
def get_model_loader(filename):
"""
Get a corresponding model loader by looking at the file name
:return: either a ParamRestore or SaverRestore
Get a corresponding model loader by looking at the file name.
Returns:
SessionInit: either a :class:`ParamRestore` (if the filename ends with '.npy') or
:class:`SaverRestore` (otherwise).
"""
if filename.endswith('.npy'):
assert os.path.isfile(filename), filename
......
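A sketch of how the documented initializers might be combined (import path assumed; the checkpoint and npy paths are hypothetical, and both loaders check that the files exist):

    import numpy as np
    from tensorpack.tfutils.sessinit import (
        SaverRestore, ParamRestore, ChainInit, get_model_loader)   # module path assumed

    # pick a loader based on the file extension
    init = get_model_loader('train_log/model-10000')   # -> SaverRestore
    init = get_model_loader('pretrained.npy')          # -> ParamRestore

    # or compose several initializers, e.g. a checkpoint plus some extra parameters
    init = ChainInit([SaverRestore('train_log/model-10000'),
                      ParamRestore({'fc0/W': np.zeros((100, 10), dtype='float32')})])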
......@@ -102,8 +102,10 @@ def add_param_summary(*summary_lists):
def add_moving_summary(v, *args):
"""
:param v: tensor or list of tensor to summary
:param args: tensors to summary
Args:
v (tf.Tensor or list): tensor or list of tensors to summary. Must have
scalar type.
args: additional tensors to summarize (passed as positional arguments)
"""
ctx = get_current_tower_context()
if ctx is not None and not ctx.is_main_training_tower:
......@@ -119,9 +121,14 @@ def add_moving_summary(v, *args):
@memoized
def summary_moving_average(tensors=None):
"""
Create a MovingAverage op and add summary for tensors
:param tensors: list of tf.Tensor to summary. default to the collection MOVING_SUMMARY_VARS_KEY
:returns: a op to maintain these average.
Create a MovingAverage Op and add summary Op for all the moving averages.
This is called by the trainer.
Args:
tensors(list): list of tf.Tensor to summarize. Defaults to the
collection ``MOVING_SUMMARY_VARS_KEY``.
Returns:
tf.Operation: an op to maintain these average.
"""
if tensors is None:
tensors = set(tf.get_collection(MOVING_SUMMARY_VARS_KEY))
......
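A minimal sketch of ``add_moving_summary`` with scalar tensors (import path assumed; the tensors are toy placeholders):

    import tensorflow as tf
    from tensorpack.tfutils.summary import add_moving_summary   # module path assumed

    cost = tf.reduce_mean(tf.square(tf.random_normal([32])), name='cost')
    accuracy = tf.constant(0.9, name='accuracy')
    # scalar tensors can be passed as a list or as positional arguments
    add_moving_summary(cost, accuracy)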
......@@ -8,9 +8,13 @@ import numpy as np
def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
"""
:param logits: NxC
:param label: N
:returns: a float32 vector of length N with 0/1 values. 1 means incorrect prediction
Args:
logits: (N,C)
label: (N,)
topk(int): a prediction counts as correct when the label is among the top ``topk`` logits.
Returns:
a float32 vector of length N with 0/1 values. 1 means incorrect
prediction.
"""
return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)),
tf.float32, name=name)
......@@ -39,9 +43,11 @@ def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
as in `Holistically-Nested Edge Detection
<http://arxiv.org/abs/1504.06375>`_.
:param pred: size: b x ANYTHING. the predictions in [0,1].
:param label: size: b x ANYTHING. the ground truth in {0,1}.
:returns: class-balanced cross entropy loss
Args:
pred: of shape (b, ...). the predictions in [0,1].
label: of the same shape. the ground truth in {0,1}.
Returns:
class-balanced cross entropy loss.
"""
z = batch_flatten(pred)
y = tf.cast(batch_flatten(label), tf.float32)
......@@ -59,14 +65,8 @@ def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
"""
The class-balanced cross entropy loss,
as in `Holistically-Nested Edge Detection
<http://arxiv.org/abs/1504.06375>`_.
This is more numerically stable than class_balanced_cross_entropy
:param logits: size: the logits.
:param label: size: the ground truth in {0,1}, of the same shape as logits.
:returns: a scalar. class-balanced cross entropy loss
This function accepts logits rather than predictions, and is more numerically stable than
:func:`class_balanced_cross_entropy`.
"""
y = tf.cast(label, tf.float32)
......@@ -77,17 +77,12 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss
pos_weight = beta / (1 - beta)
cost = tf.nn.weighted_cross_entropy_with_logits(logits, y, pos_weight)
cost = tf.reduce_mean(cost * (1 - beta), name=name)
# logstable = tf.log(1 + tf.exp(-tf.abs(z)))
# loss_pos = -beta * tf.reduce_mean(-y * (logstable - tf.minimum(0.0, z)))
# loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) * (logstable + tf.maximum(z, 0.0)))
# cost = tf.sub(loss_pos, loss_neg, name=name)
return cost
def print_stat(x, message=None):
""" a simple print op.
Use it like: x = print_stat(x)
""" A simple print Op that might be easier to use than :meth:`tf.Print`.
Use it like: ``x = print_stat(x, message='This is x')``.
"""
if message is None:
message = x.op.name
......@@ -96,6 +91,10 @@ def print_stat(x, message=None):
def rms(x, name=None):
"""
Returns:
root mean square of tensor x.
"""
if name is None:
name = x.op.name + '/rms'
with tf.name_scope(None): # name already contains the scope
......@@ -104,19 +103,41 @@ def rms(x, name=None):
def huber_loss(x, delta=1, name='huber_loss'):
r"""
Huber loss of x.
.. math::
y = \begin{cases} \frac{x^2}{2}, & |x| < \delta \\
\delta |x| - \frac{\delta^2}{2}, & |x| \ge \delta
\end{cases}
Args:
x: the difference vector.
delta (float): the threshold at which the loss changes from quadratic to linear.
Returns:
a tensor of the same shape of x.
"""
sqrcost = tf.square(x)
abscost = tf.abs(x)
return tf.reduce_sum(
tf.select(abscost < delta,
sqrcost * 0.5,
abscost * delta - 0.5 * delta ** 2),
name=name)
return tf.select(abscost < delta,
sqrcost * 0.5,
abscost * delta - 0.5 * delta ** 2,
name=name)
def get_scalar_var(name, init_value, summary=False, trainable=False):
"""
get a scalar variable with certain initial value
:param summary: summary this variable
Get a scalar variable with a certain initial value.
Args:
name (str): name of the variable.
init_value (float): initial value.
summary (bool): whether to summary this variable.
trainable (bool): trainable or not.
Returns:
tf.Variable: the variable
"""
ret = tf.get_variable(name, shape=[],
initializer=tf.constant_initializer(init_value),
......
......@@ -13,9 +13,14 @@ _CurrentTowerContext = None
class TowerContext(object):
""" A context where the current model is being built in. """
def __init__(self, tower_name, is_training=None):
""" tower_name: 'tower0', 'towerp0', or '' """
"""
Args:
tower_name (str): 'tower0', 'towerp0', or ''
is_training (bool): if None, automatically determine from tower_name.
"""
self._name = tower_name
if is_training is None:
is_training = not self._name.startswith(PREDICT_TOWER)
......@@ -39,12 +44,15 @@ class TowerContext(object):
def get_variable_on_tower(self, *args, **kwargs):
"""
Get a variable for this tower specifically, without reusing.
Tensorflow doesn't allow reuse=False scope under a
reuse=True scope. This method provides a work around.
Get a variable for this tower specifically, without reusing, even if
it is called under a ``reuse=True`` variable scope.
Tensorflow doesn't allow us to disable reuse under a
``reuse=True`` scope. This method provides a workaround.
See https://www.tensorflow.org/versions/master/how_tos/variable_scope/index.html#basics-of-tfvariable-scope
:param args, kwargs: same as tf.get_variable()
Args:
args: same as ``tf.get_variable()``.
"""
with tf.variable_scope(self._name) as scope:
with tf.variable_scope(scope, reuse=False):
......
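A hedged sketch of ``get_variable_on_tower`` (import path assumed; the variable name is hypothetical):

    import tensorflow as tf
    from tensorpack.tfutils.tower import TowerContext   # module path assumed

    ctx = TowerContext('tower0', is_training=True)
    # creates a fresh variable under the 'tower0' scope, even when the calling
    # code sits inside a reuse=True variable scope
    step = ctx.get_variable_on_tower(
        'local_step', shape=[], initializer=tf.constant_initializer(0.0))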
......@@ -14,17 +14,20 @@ from ..utils.naming import PREDICT_TOWER
from .common import get_op_tensor_name
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
'get_savename_from_varname', 'is_training_name']
'get_savename_from_varname', 'is_training_name',
'get_checkpoint_path']
def get_savename_from_varname(
varname, varname_prefix=None,
savename_prefix=None):
"""
:param varname: a variable name in the graph
:param varname_prefix: an optional prefix that may need to be removed in varname
:param savename_prefix: an optional prefix to append to all savename
:returns: the name used to save the variable
Args:
varname(str): a variable name in the graph
varname_prefix(str): an optional prefix that may need to be removed in varname
savename_prefix(str): an optional prefix to add to every savename
Returns:
str: the name used to save the variable
"""
name = varname
if PREDICT_TOWER in name:
......@@ -46,7 +49,9 @@ class SessionUpdate(object):
def __init__(self, sess, vars_to_update):
"""
:param vars_to_update: a collection of variables to update
Args:
sess (tf.Session): a session object
vars_to_update: a collection of variables to update
"""
self.sess = sess
self.assign_ops = defaultdict(list)
......@@ -60,8 +65,9 @@ class SessionUpdate(object):
def update(self, prms):
"""
:param prms: dict of {variable name: value}
Any name in prms must be in the graph and in vars_to_update.
Args:
prms(dict): dict of {variable name: value}
Any name in prms must be in the graph and in vars_to_update.
"""
for name, value in six.iteritems(prms):
assert name in self.assign_ops
......@@ -77,8 +83,12 @@ class SessionUpdate(object):
def dump_session_params(path):
""" Dump value of all trainable + to_save variables to a dict and save to `path` as
npy format, loadable by ParamRestore
"""
Dump value of all TRAINABLE + MODEL variables to a dict, and save as
npy format (loadable by :class:`ParamRestore`).
Args:
path(str): the path to save the parameters.
"""
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
......@@ -96,10 +106,42 @@ the same name".format(v.name))
np.save(path, result)
def dump_chkpt_vars(model_path):
""" Dump all variables from a checkpoint to a dict"""
def get_checkpoint_path(model_path):
"""
Work around TF problems in checkpoint path handling.
Args:
model_path: a user-input path
Returns:
str: the argument that can be passed to NewCheckpointReader
"""
if os.path.basename(model_path) == model_path:
model_path = os.path.join('.', model_path) # avoid #4921
model_path = os.path.join('.', model_path) # avoid #4921 and #6142
if os.path.basename(model_path) == 'checkpoint':
model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
# to be consistent with either v1 or v2
# fix paths if provided a wrong one
new_path = model_path
if '00000-of-00001' in model_path:
new_path = model_path.split('.data')[0]
elif model_path.endswith('.index'):
new_path = model_path.split('.index')[0]
if new_path != model_path:
logger.warn(
"[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
model_path = new_path
assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
return model_path
def dump_chkpt_vars(model_path):
""" Dump all variables from a checkpoint to a dict.
Args:
model_path(str): path to a checkpoint.
"""
model_path = get_checkpoint_path(model_path)
reader = tf.train.NewCheckpointReader(model_path)
var_names = reader.get_variable_to_shape_map().keys()
result = {}
......@@ -110,8 +152,10 @@ def dump_chkpt_vars(model_path):
def is_training_name(name):
"""
This is only used to improve logging.
:returns: guess whether this tensor is something only used in training.
This is a hack temporarily used to improve logging. Do not use this function.
Returns:
bool: Guess whether this tensor is something only used in training.
"""
# TODO: maybe simply check against TRAINABLE_VARIABLES and MODEL_VARIABLES?
# TODO or use get_slot_names()
......
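A sketch of the new ``get_checkpoint_path`` together with ``dump_chkpt_vars`` (import path assumed; the checkpoint paths are hypothetical and the function asserts they exist on disk):

    from tensorpack.tfutils.varmanip import get_checkpoint_path, dump_chkpt_vars   # module path assumed

    # these user inputs all resolve to the same loadable checkpoint prefix
    path = get_checkpoint_path('train_log/model-10000.index')
    path = get_checkpoint_path('train_log/model-10000.data-00000-of-00001')

    params = dump_chkpt_vars('train_log/model-10000')   # {varname: ndarray}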
......@@ -14,6 +14,13 @@ _ORIG_GET_VARIABLE = tf.get_variable
@contextmanager
def replace_get_variable(fn):
"""
Args:
fn: a function taking the same arguments as ``tf.get_variable``.
Returns:
a context where ``tf.get_variable`` and
``variable_scope.get_variable`` are replaced with ``fn``.
"""
old_getv = tf.get_variable
old_vars_getv = variable_scope.get_variable
......@@ -26,8 +33,10 @@ def replace_get_variable(fn):
def freeze_get_variable():
"""
Return a contextmanager, where all variables returned by
`get_variable` will have no gradients.
Return a context, where all variables (reused or not) returned by
``get_variable`` will have no gradients (surrounded by ``tf.stop_gradient``).
But they will still be in ``TRAINABLE_VARIABLES`` collections so they will get
saved correctly. This is useful to fix certain variables for fine-tuning.
Example:
.. code-block:: python
......
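The example in this docstring is truncated here; a hedged sketch of what using ``freeze_get_variable`` could look like (import path and variable name are assumptions):

    import tensorflow as tf
    from tensorpack.tfutils import varreplace   # module path assumed

    with varreplace.freeze_get_variable():
        # variables created here are returned wrapped in tf.stop_gradient,
        # so they receive no gradient but are still saved with the model
        w = tf.get_variable('frozen_w', shape=[10, 10])
    out = tf.matmul(tf.random_normal([1, 10]), w)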