Commit 6f55416f authored by Yuxin Wu's avatar Yuxin Wu

update API docs on tfutils

parent 403815b5
......@@ -358,17 +358,18 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'replace_get_variable',
'remap_get_variable',
'freeze_get_variable',
'Triggerable',
'predictor_factory',
'get_predictors',
'RandomCropAroundBox',
'GaussianDeform',
'dump_chkpt_vars',
'VisualQA',
'huber_loss',
'DumpTensor',
'StagingInputWrapper',
'StepTensorPrinter'
'StepTensorPrinter',
'guided_relu', 'saliency_map', 'get_scalar_var',
'prediction_incorrect', 'huber_loss',
]:
return True
if name in ['get_data', 'size', 'reset_state']:
......
......@@ -23,6 +23,14 @@ tensorpack.tfutils.gradproc module
:undoc-members:
:show-inheritance:
tensorpack.tfutils.tower module
------------------------------------
.. automodule:: tensorpack.tfutils.tower
:members:
:undoc-members:
:show-inheritance:
tensorpack.tfutils.scope_utils module
--------------------------------------
......@@ -47,6 +55,14 @@ tensorpack.tfutils.sesscreate module
:undoc-members:
:show-inheritance:
tensorpack.tfutils.sessinit module
------------------------------------
.. automodule:: tensorpack.tfutils.sessinit
:members:
:undoc-members:
:show-inheritance:
tensorpack.tfutils.summary module
---------------------------------
......@@ -79,11 +95,11 @@ tensorpack.tfutils.varreplace module
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: tensorpack.tfutils
:members:
:undoc-members:
:show-inheritance:
Other functions in tensorpack.tfutils module
---------------------------------------------
.. automethod:: tensorpack.tfutils.get_default_sess_config
.. automethod:: tensorpack.tfutils.get_global_step_var
.. automethod:: tensorpack.tfutils.get_global_step_value
.. automethod:: tensorpack.tfutils.argscope
.. automethod:: tensorpack.tfutils.get_arg_scope
......@@ -8,7 +8,7 @@ import six
from ..utils.develop import log_deprecated
from ..tfutils.common import get_op_or_tensor_by_name
__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory', 'Triggerable']
__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory']
@six.add_metaclass(ABCMeta)
......@@ -206,10 +206,6 @@ class Callback(object):
return type(self).__name__
# back-compat. in case someone write something in triggerable
Triggerable = Callback
class ProxyCallback(Callback):
""" A callback which proxy all methods to another callback.
It's useful as a base class of callbacks which decorate other callbacks.
......
......@@ -22,10 +22,13 @@ def _global_import(name):
_CURR_DIR = os.path.dirname(__file__)
_SKIP = ['utils', 'registry']
for _, module_name, _ in iter_modules(
[_CURR_DIR]):
srcpath = os.path.join(_CURR_DIR, module_name + '.py')
if not os.path.isfile(srcpath):
continue
if not module_name.startswith('_'):
if module_name.startswith('_'):
continue
if module_name not in _SKIP:
_global_import(module_name)
......@@ -4,3 +4,5 @@
from .registry import layer_register # noqa
from .utils import VariableHolder, rename_get_variable # noqa
__all__ = ['layer_register', 'VariableHolder']
......@@ -5,7 +5,9 @@
from pkgutil import iter_modules
import os
__all__ = []
from .tower import get_current_tower_context, TowerContext
# don't want to include everything from .tower
__all__ = ['get_current_tower_context', 'TowerContext']
def _global_import(name):
......@@ -21,7 +23,6 @@ _TO_IMPORT = set([
'common',
'sessinit',
'argscope',
'tower',
])
_CURR_DIR = os.path.dirname(__file__)
......@@ -36,4 +37,4 @@ for _, module_name, _ in iter_modules(
_global_import(module_name) # import the content to tfutils.*
__all__.extend(['sessinit', 'summary', 'optimizer',
'sesscreate', 'gradproc', 'varreplace', 'symbolic_functions',
'distributed'])
'distributed', 'tower'])
......@@ -13,15 +13,6 @@ from ..utils.develop import deprecated
# this function exists for backwards-compatibilty
def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
"""
Args:
logits: shape [B,C].
label: shape [B].
topk(int): topk
Returns:
a float32 vector of length N with 0/1 values. 1 means incorrect
prediction.
"""
return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)),
tf.float32, name=name)
......
......@@ -8,7 +8,7 @@ from contextlib import contextmanager
from ..utils.develop import deprecated
__all__ = ['custom_getter_scope', 'replace_get_variable',
__all__ = ['replace_get_variable',
'freeze_variables', 'freeze_get_variable', 'remap_get_variable',
'remap_variables']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment