Commit 6f55416f authored by Yuxin Wu's avatar Yuxin Wu

update API docs on tfutils

parent 403815b5
...@@ -358,17 +358,18 @@ def autodoc_skip_member(app, what, name, obj, skip, options): ...@@ -358,17 +358,18 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'replace_get_variable', 'replace_get_variable',
'remap_get_variable', 'remap_get_variable',
'freeze_get_variable', 'freeze_get_variable',
'Triggerable',
'predictor_factory', 'predictor_factory',
'get_predictors', 'get_predictors',
'RandomCropAroundBox', 'RandomCropAroundBox',
'GaussianDeform', 'GaussianDeform',
'dump_chkpt_vars', 'dump_chkpt_vars',
'VisualQA', 'VisualQA',
'huber_loss',
'DumpTensor', 'DumpTensor',
'StagingInputWrapper', 'StagingInputWrapper',
'StepTensorPrinter' 'StepTensorPrinter',
'guided_relu', 'saliency_map', 'get_scalar_var',
'prediction_incorrect', 'huber_loss',
]: ]:
return True return True
if name in ['get_data', 'size', 'reset_state']: if name in ['get_data', 'size', 'reset_state']:
......
...@@ -23,6 +23,14 @@ tensorpack.tfutils.gradproc module ...@@ -23,6 +23,14 @@ tensorpack.tfutils.gradproc module
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
tensorpack.tfutils.tower module
------------------------------------
.. automodule:: tensorpack.tfutils.tower
:members:
:undoc-members:
:show-inheritance:
tensorpack.tfutils.scope_utils module tensorpack.tfutils.scope_utils module
-------------------------------------- --------------------------------------
...@@ -47,6 +55,14 @@ tensorpack.tfutils.sesscreate module ...@@ -47,6 +55,14 @@ tensorpack.tfutils.sesscreate module
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
tensorpack.tfutils.sessinit module
------------------------------------
.. automodule:: tensorpack.tfutils.sessinit
:members:
:undoc-members:
:show-inheritance:
tensorpack.tfutils.summary module tensorpack.tfutils.summary module
--------------------------------- ---------------------------------
...@@ -79,11 +95,11 @@ tensorpack.tfutils.varreplace module ...@@ -79,11 +95,11 @@ tensorpack.tfutils.varreplace module
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
Module contents Other functions in tensorpack.tfutils module
--------------- ---------------------------------------------
.. automodule:: tensorpack.tfutils
:members:
:undoc-members:
:show-inheritance:
.. automethod:: tensorpack.tfutils.get_default_sess_config
.. automethod:: tensorpack.tfutils.get_global_step_var
.. automethod:: tensorpack.tfutils.get_global_step_value
.. automethod:: tensorpack.tfutils.argscope
.. automethod:: tensorpack.tfutils.get_arg_scope
...@@ -8,7 +8,7 @@ import six ...@@ -8,7 +8,7 @@ import six
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
from ..tfutils.common import get_op_or_tensor_by_name from ..tfutils.common import get_op_or_tensor_by_name
__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory', 'Triggerable'] __all__ = ['Callback', 'ProxyCallback', 'CallbackFactory']
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
...@@ -206,10 +206,6 @@ class Callback(object): ...@@ -206,10 +206,6 @@ class Callback(object):
return type(self).__name__ return type(self).__name__
# back-compat. in case someone write something in triggerable
Triggerable = Callback
class ProxyCallback(Callback): class ProxyCallback(Callback):
""" A callback which proxy all methods to another callback. """ A callback which proxy all methods to another callback.
It's useful as a base class of callbacks which decorate other callbacks. It's useful as a base class of callbacks which decorate other callbacks.
......
...@@ -22,10 +22,13 @@ def _global_import(name): ...@@ -22,10 +22,13 @@ def _global_import(name):
_CURR_DIR = os.path.dirname(__file__) _CURR_DIR = os.path.dirname(__file__)
_SKIP = ['utils', 'registry']
for _, module_name, _ in iter_modules( for _, module_name, _ in iter_modules(
[_CURR_DIR]): [_CURR_DIR]):
srcpath = os.path.join(_CURR_DIR, module_name + '.py') srcpath = os.path.join(_CURR_DIR, module_name + '.py')
if not os.path.isfile(srcpath): if not os.path.isfile(srcpath):
continue continue
if not module_name.startswith('_'): if module_name.startswith('_'):
continue
if module_name not in _SKIP:
_global_import(module_name) _global_import(module_name)
...@@ -4,3 +4,5 @@ ...@@ -4,3 +4,5 @@
from .registry import layer_register # noqa from .registry import layer_register # noqa
from .utils import VariableHolder, rename_get_variable # noqa from .utils import VariableHolder, rename_get_variable # noqa
__all__ = ['layer_register', 'VariableHolder']
...@@ -5,7 +5,9 @@ ...@@ -5,7 +5,9 @@
from pkgutil import iter_modules from pkgutil import iter_modules
import os import os
__all__ = [] from .tower import get_current_tower_context, TowerContext
# don't want to include everything from .tower
__all__ = ['get_current_tower_context', 'TowerContext']
def _global_import(name): def _global_import(name):
...@@ -21,7 +23,6 @@ _TO_IMPORT = set([ ...@@ -21,7 +23,6 @@ _TO_IMPORT = set([
'common', 'common',
'sessinit', 'sessinit',
'argscope', 'argscope',
'tower',
]) ])
_CURR_DIR = os.path.dirname(__file__) _CURR_DIR = os.path.dirname(__file__)
...@@ -36,4 +37,4 @@ for _, module_name, _ in iter_modules( ...@@ -36,4 +37,4 @@ for _, module_name, _ in iter_modules(
_global_import(module_name) # import the content to tfutils.* _global_import(module_name) # import the content to tfutils.*
__all__.extend(['sessinit', 'summary', 'optimizer', __all__.extend(['sessinit', 'summary', 'optimizer',
'sesscreate', 'gradproc', 'varreplace', 'symbolic_functions', 'sesscreate', 'gradproc', 'varreplace', 'symbolic_functions',
'distributed']) 'distributed', 'tower'])
...@@ -13,15 +13,6 @@ from ..utils.develop import deprecated ...@@ -13,15 +13,6 @@ from ..utils.develop import deprecated
# this function exists for backwards-compatibilty # this function exists for backwards-compatibilty
def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'): def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
"""
Args:
logits: shape [B,C].
label: shape [B].
topk(int): topk
Returns:
a float32 vector of length N with 0/1 values. 1 means incorrect
prediction.
"""
return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)), return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)),
tf.float32, name=name) tf.float32, name=name)
......
...@@ -8,7 +8,7 @@ from contextlib import contextmanager ...@@ -8,7 +8,7 @@ from contextlib import contextmanager
from ..utils.develop import deprecated from ..utils.develop import deprecated
__all__ = ['custom_getter_scope', 'replace_get_variable', __all__ = ['replace_get_variable',
'freeze_variables', 'freeze_get_variable', 'remap_get_variable', 'freeze_variables', 'freeze_get_variable', 'remap_get_variable',
'remap_variables'] 'remap_variables']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment