Commit a47c9980 authored by Yuxin Wu's avatar Yuxin Wu

hide some internal functions to develop.py.

parent 3e61aacd
tensorpack.tfutils package tensorpack.tfutils package
========================== ==========================
tensorpack.tfutils.collection module
------------------------------------
.. automodule:: tensorpack.tfutils.collection
:members:
:undoc-members:
:show-inheritance:
tensorpack.tfutils.distributions module
---------------------------------------
.. automodule:: tensorpack.tfutils.distributions
:members:
:undoc-members:
:show-inheritance:
tensorpack.tfutils.gradproc module
------------------------------------
.. automodule:: tensorpack.tfutils.gradproc
:members:
:undoc-members:
:show-inheritance:
tensorpack.tfutils.modelutils module tensorpack.tfutils.modelutils module
------------------------------------ ------------------------------------
...@@ -9,6 +33,22 @@ tensorpack.tfutils.modelutils module ...@@ -9,6 +33,22 @@ tensorpack.tfutils.modelutils module
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
tensorpack.tfutils.optimizer module
------------------------------------
.. automodule:: tensorpack.tfutils.optimizer
:members:
:undoc-members:
:show-inheritance:
tensorpack.tfutils.sesscreate module
------------------------------------
.. automodule:: tensorpack.tfutils.sesscreate
:members:
:undoc-members:
:show-inheritance:
tensorpack.tfutils.summary module tensorpack.tfutils.summary module
--------------------------------- ---------------------------------
......
...@@ -17,13 +17,6 @@ tensorpack.utils.concurrency module ...@@ -17,13 +17,6 @@ tensorpack.utils.concurrency module
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
tensorpack.utils.debug module
-----------------------------
.. automodule:: tensorpack.utils.debug
:members:
:undoc-members:
:show-inheritance:
tensorpack.utils.discretize module tensorpack.utils.discretize module
---------------------------------- ----------------------------------
......
...@@ -9,7 +9,7 @@ from collections import defaultdict ...@@ -9,7 +9,7 @@ from collections import defaultdict
import six import six
from ..utils import get_rng from ..utils import get_rng
__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer', __all__ = ['RLEnvironment', 'ProxyPlayer',
'DiscreteActionSpace'] 'DiscreteActionSpace']
......
...@@ -79,7 +79,7 @@ try: ...@@ -79,7 +79,7 @@ try:
# https://github.com/openai/gym/pull/199 # https://github.com/openai/gym/pull/199
# not sure does it cause other problems # not sure does it cause other problems
except ImportError: except ImportError:
from ..utils.dependency import create_dummy_class from ..utils.develop import create_dummy_class
GymEnv = create_dummy_class('GymEnv', 'gym') # noqa GymEnv = create_dummy_class('GymEnv', 'gym') # noqa
......
...@@ -91,7 +91,7 @@ class BSDS500(RNGDataFlow): ...@@ -91,7 +91,7 @@ class BSDS500(RNGDataFlow):
try: try:
from scipy.io import loadmat from scipy.io import loadmat
except ImportError: except ImportError:
from ...utils.dependency import create_dummy_class from ...utils.develop import create_dummy_class
BSDS500 = create_dummy_class('BSDS500', 'scipy.io') # noqa BSDS500 = create_dummy_class('BSDS500', 'scipy.io') # noqa
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -72,7 +72,7 @@ class SVHNDigit(RNGDataFlow): ...@@ -72,7 +72,7 @@ class SVHNDigit(RNGDataFlow):
try: try:
import scipy.io import scipy.io
except ImportError: except ImportError:
from ...utils.dependency import create_dummy_class from ...utils.develop import create_dummy_class
SVHNDigit = create_dummy_class('SVHNDigit', 'scipy.io') # noqa SVHNDigit = create_dummy_class('SVHNDigit', 'scipy.io') # noqa
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -82,7 +82,7 @@ def dump_dataflow_to_lmdb(ds, lmdb_path): ...@@ -82,7 +82,7 @@ def dump_dataflow_to_lmdb(ds, lmdb_path):
try: try:
import lmdb import lmdb
except ImportError: except ImportError:
from ..utils.dependency import create_dummy_func from ..utils.develop import create_dummy_func
dump_dataflow_to_lmdb = create_dummy_func('dump_dataflow_to_lmdb', 'lmdb') # noqa dump_dataflow_to_lmdb = create_dummy_func('dump_dataflow_to_lmdb', 'lmdb') # noqa
......
...@@ -181,9 +181,9 @@ def CaffeLMDB(lmdb_path, shuffle=True, keys=None): ...@@ -181,9 +181,9 @@ def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
a :class:`LMDBDataDecoder` instance. a :class:`LMDBDataDecoder` instance.
Example: Example:
.. code-block:: python
.. code-block:: none ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')
ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')
""" """
cpb = get_caffe_pb() cpb = get_caffe_pb()
...@@ -226,7 +226,7 @@ class SVMLightData(RNGDataFlow): ...@@ -226,7 +226,7 @@ class SVMLightData(RNGDataFlow):
yield [self.X[id, :], self.y[id]] yield [self.X[id, :], self.y[id]]
from ..utils.dependency import create_dummy_class # noqa from ..utils.develop import create_dummy_class # noqa
try: try:
import h5py import h5py
except ImportError: except ImportError:
......
...@@ -217,7 +217,7 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5, ...@@ -217,7 +217,7 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
""" """
Batch Renormalization layer, as described in the paper: Batch Renormalization layer, as described in the paper:
`Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models `Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
<https://arxiv.org/abs/1702.03275>`_. <https://arxiv.org/abs/1702.03275>`_.
Args: Args:
x (tf.Tensor): a NHWC or NC tensor. x (tf.Tensor): a NHWC or NC tensor.
......
...@@ -10,7 +10,8 @@ import copy ...@@ -10,7 +10,8 @@ import copy
from ..tfutils.argscope import get_arg_scope from ..tfutils.argscope import get_arg_scope
from ..tfutils.modelutils import get_shape_str from ..tfutils.modelutils import get_shape_str
from ..tfutils.summary import add_activation_summary from ..tfutils.summary import add_activation_summary
from ..utils import logger, building_rtfd from ..utils import logger
from ..utils.develop import building_rtfd
# make sure each layer is only logged once # make sure each layer is only logged once
_LAYER_LOGGED = set() _LAYER_LOGGED = set()
......
...@@ -8,7 +8,9 @@ import tensorflow as tf ...@@ -8,7 +8,9 @@ import tensorflow as tf
import pickle import pickle
import six import six
from ..utils import logger, INPUTS_KEY, deprecated, log_deprecated from ..utils import logger
from ..utils.naming import INPUTS_KEY
from ..utils.develop import deprecated, log_deprecated
from ..utils.argtools import memoized from ..utils.argtools import memoized
from ..tfutils.modelutils import apply_slim_collections from ..tfutils.modelutils import apply_slim_collections
......
...@@ -7,7 +7,8 @@ from abc import abstractmethod, ABCMeta ...@@ -7,7 +7,8 @@ from abc import abstractmethod, ABCMeta
import tensorflow as tf import tensorflow as tf
import six import six
from ..utils import logger, deprecated from ..utils import logger
from ..utils.develop import deprecated
from ..utils.argtools import memoized from ..utils.argtools import memoized
from ..utils.naming import SUMMARY_BACKUP_KEYS from ..utils.naming import SUMMARY_BACKUP_KEYS
from ..tfutils import get_tensors_by_names, TowerContext from ..tfutils import get_tensors_by_names, TowerContext
...@@ -60,8 +61,10 @@ class PredictorBase(object): ...@@ -60,8 +61,10 @@ class PredictorBase(object):
@abstractmethod @abstractmethod
def _do_call(self, dp): def _do_call(self, dp):
""" """
:param dp: input datapoint. must have the same length as input_names Args:
:return: output as defined by the config dp: input datapoint. must have the same length as input_names
Returns:
output as defined by the config
""" """
......
...@@ -8,7 +8,8 @@ import six ...@@ -8,7 +8,8 @@ import six
from six.moves import queue, range from six.moves import queue, range
import tensorflow as tf import tensorflow as tf
from ..utils import logger, deprecated from ..utils import logger
from ..utils.develop import deprecated
from ..utils.concurrency import DIE, StoppableThread, ShareSessionThread from ..utils.concurrency import DIE, StoppableThread, ShareSessionThread
from ..tfutils.modelutils import describe_model from ..tfutils.modelutils import describe_model
from .base import OnlinePredictor, OfflinePredictor, AsyncPredictorBase from .base import OnlinePredictor, OfflinePredictor, AsyncPredictorBase
...@@ -185,5 +186,5 @@ try: ...@@ -185,5 +186,5 @@ try:
else: else:
from concurrent.futures import Future from concurrent.futures import Future
except ImportError: except ImportError:
from ..utils.dependency import create_dummy_class from ..utils.develop import create_dummy_class
MultiThreadAsyncPredictor = create_dummy_class('MultiThreadAsyncPredictor', 'tornado.concurrent') # noqa MultiThreadAsyncPredictor = create_dummy_class('MultiThreadAsyncPredictor', 'tornado.concurrent') # noqa
...@@ -6,7 +6,7 @@ import tensorflow as tf ...@@ -6,7 +6,7 @@ import tensorflow as tf
import six import six
from ..models import ModelDesc from ..models import ModelDesc
from ..utils import log_deprecated from ..utils.develop import log_deprecated
from ..tfutils import get_default_sess_config from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession from ..tfutils.sessinit import SessionInit, JustCurrentSession
from ..tfutils.sesscreate import NewSession from ..tfutils.sesscreate import NewSession
......
...@@ -76,10 +76,11 @@ class MultiTowerOfflinePredictor(OnlinePredictor): ...@@ -76,10 +76,11 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
class DataParallelOfflinePredictor(OnlinePredictor): class DataParallelOfflinePredictor(OnlinePredictor):
""" A data-parallel predictor. """
Its input is: [input[0] in tower[0], input[1] in tower[0], ..., A data-parallel predictor.
input[0] in tower[1], input[1] in tower[1], ...] Note that it doesn't split/concat inputs/outputs automatically.
And same for the output. Its input is: ``[input[0] in tower[0], input[1] in tower[0], ..., input[0] in tower[1], input[1] in tower[1], ...]``
And same for the output.
""" """
def __init__(self, config, towers): def __init__(self, config, towers):
......
...@@ -6,7 +6,8 @@ import six ...@@ -6,7 +6,8 @@ import six
import tensorflow as tf import tensorflow as tf
import re import re
from ..utils import log_deprecated, logger from ..utils import logger
from ..utils.develop import log_deprecated
from ..utils.naming import MOVING_SUMMARY_OPS_KEY from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .tower import get_current_tower_context from .tower import get_current_tower_context
from .symbolic_functions import rms from .symbolic_functions import rms
......
...@@ -12,7 +12,8 @@ from six.moves import range ...@@ -12,7 +12,8 @@ from six.moves import range
import tensorflow as tf import tensorflow as tf
from .predict import PredictorFactory from .predict import PredictorFactory
from .config import TrainConfig from .config import TrainConfig
from ..utils import logger, deprecated, log_deprecated from ..utils import logger
from ..utils.develop import deprecated, log_deprecated
from ..callbacks import StatHolder from ..callbacks import StatHolder
from ..tfutils import get_global_step_value from ..tfutils import get_global_step_value
from ..tfutils.modelutils import describe_model from ..tfutils.modelutils import describe_model
......
...@@ -10,7 +10,8 @@ from ..callbacks import ( ...@@ -10,7 +10,8 @@ from ..callbacks import (
MaintainStepCounter) MaintainStepCounter)
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
from ..models import ModelDesc from ..models import ModelDesc
from ..utils import logger, log_deprecated from ..utils import logger
from ..utils.develop import log_deprecated
from ..tfutils import (JustCurrentSession, from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit) get_default_sess_config, SessionInit)
from ..tfutils.optimizer import apply_grad_processors from ..tfutils.optimizer import apply_grad_processors
...@@ -130,6 +131,8 @@ class TrainConfig(object): ...@@ -130,6 +131,8 @@ class TrainConfig(object):
self.predict_tower = predict_tower self.predict_tower = predict_tower
if isinstance(self.predict_tower, int): if isinstance(self.predict_tower, int):
self.predict_tower = [self.predict_tower] self.predict_tower = [self.predict_tower]
assert len(set(self.predict_tower)) == len(self.predict_tower), \
"Cannot have duplicated predict_tower!"
if 'optimizer' in kwargs: if 'optimizer' in kwargs:
log_deprecated("TrainConfig(optimizer=...)", log_deprecated("TrainConfig(optimizer=...)",
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import tensorflow as tf import tensorflow as tf
from ..utils import log_deprecated from ..utils.develop import log_deprecated
from ..tfutils.tower import TowerContext, get_current_tower_context from ..tfutils.tower import TowerContext, get_current_tower_context
from .input_data import QueueInput, FeedfreeInput from .input_data import QueueInput, FeedfreeInput
......
...@@ -8,7 +8,8 @@ import itertools ...@@ -8,7 +8,8 @@ import itertools
import re import re
from six.moves import zip, range from six.moves import zip, range
from ..utils import logger, log_deprecated from ..utils import logger
from ..utils.develop import log_deprecated
from ..utils.naming import SUMMARY_BACKUP_KEYS from ..utils.naming import SUMMARY_BACKUP_KEYS
from ..utils.concurrency import LoopThread from ..utils.concurrency import LoopThread
from ..tfutils.tower import TowerContext from ..tfutils.tower import TowerContext
......
...@@ -33,7 +33,7 @@ class PredictorFactory(object): ...@@ -33,7 +33,7 @@ class PredictorFactory(object):
Returns: Returns:
an online predictor (which has to be used under a default session) an online predictor (which has to be used under a default session)
""" """
tower = self.towers[tower] # TODO is it good? tower = self.towers[tower]
with tf.variable_scope(tf.get_variable_scope(), reuse=True): with tf.variable_scope(tf.get_variable_scope(), reuse=True):
# just ensure the tower exists. won't rebuild # just ensure the tower exists. won't rebuild
self._tower_builder.build(tower) self._tower_builder.build(tower)
......
...@@ -94,8 +94,8 @@ def shape4d(a, data_format='NHWC'): ...@@ -94,8 +94,8 @@ def shape4d(a, data_format='NHWC'):
a: a int or tuple/list of length 2 a: a int or tuple/list of length 2
Returns: Returns:
list: of length 4. if ``a`` is a int, return ``[1, a, a, 1]`` or ``[1, list: of length 4. if ``a`` is a int, return ``[1, a, a, 1]``
1, a, a]`` depending on data_format. or ``[1, 1, a, a]`` depending on data_format.
""" """
s2d = shape2d(a) s2d = shape2d(a)
if data_format == 'NHWC': if data_format == 'NHWC':
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
import sys import sys
__all__ = ['enable_call_trace']
def enable_call_trace(): def enable_call_trace():
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: dependency.py # File: develop.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
""" Utilities to handle dependency """ """ Utilities for developers only.
These are not visible to users (not automatically imported). And should not
__all__ = ['create_dummy_func', 'create_dummy_class'] appeared in docs."""
import os
import functools
from datetime import datetime
from . import logger
def create_dummy_class(klass, dependency): def create_dummy_class(klass, dependency):
...@@ -40,3 +44,73 @@ def create_dummy_func(func, dependency): ...@@ -40,3 +44,73 @@ def create_dummy_func(func, dependency):
def _dummy(*args, **kwargs): def _dummy(*args, **kwargs):
raise ImportError("Cannot import '{}', therefore '{}' is not available".format(dependency, func)) raise ImportError("Cannot import '{}', therefore '{}' is not available".format(dependency, func))
return _dummy return _dummy
def building_rtfd():
"""
Returns:
bool: if tensorpack is being imported to generate docs now.
"""
return os.environ.get('READTHEDOCS') == 'True' \
or os.environ.get('TENSORPACK_DOC_BUILDING')
def log_deprecated(name="", text="", eos=""):
"""
Log deprecation warning.
Args:
name (str): name of the deprecated item.
text (str, optional): information about the deprecation.
eos (str, optional): end of service date such as "YYYY-MM-DD".
"""
assert name or text
if eos:
eos = "after " + datetime(*map(int, eos.split("-"))).strftime("%d %b")
if name:
if eos:
warn_msg = "%s will be deprecated %s. %s" % (name, eos, text)
else:
warn_msg = "%s was deprecated. %s" % (name, text)
else:
warn_msg = text
if eos:
warn_msg += " Legacy period ends %s" % eos
logger.warn("[Deprecated] " + warn_msg)
def deprecated(text="", eos=""):
"""
Args:
text, eos: same as :func:`log_deprecated`.
Returns:
a decorator which deprecates the function.
Example:
.. code-block:: python
@deprecated("Explanation of what to do instead.", "2017-11-4")
def foo(...):
pass
"""
def get_location():
import inspect
frame = inspect.currentframe()
if frame:
callstack = inspect.getouterframes(frame)[-1]
return '%s:%i' % (callstack[1], callstack[2])
else:
stack = inspect.stack(0)
entry = stack[2]
return '%s:%i' % (entry[1], entry[2])
def deprecated_inner(func):
@functools.wraps(func)
def new_func(*args, **kwargs):
name = "{} [{}]".format(func.__name__, get_location())
log_deprecated(name, text, eos)
return func(*args, **kwargs)
return new_func
return deprecated_inner
...@@ -9,8 +9,6 @@ import inspect ...@@ -9,8 +9,6 @@ import inspect
from datetime import datetime from datetime import datetime
from tqdm import tqdm from tqdm import tqdm
import numpy as np import numpy as np
import functools
from . import logger
__all__ = ['change_env', __all__ = ['change_env',
...@@ -18,9 +16,6 @@ __all__ = ['change_env', ...@@ -18,9 +16,6 @@ __all__ = ['change_env',
'get_tqdm_kwargs', 'get_tqdm_kwargs',
'get_tqdm', 'get_tqdm',
'execute_only_once', 'execute_only_once',
'building_rtfd',
'log_deprecated',
'deprecated'
] ]
...@@ -110,73 +105,3 @@ def get_tqdm(**kwargs): ...@@ -110,73 +105,3 @@ def get_tqdm(**kwargs):
""" Similar to :func:`get_tqdm_kwargs`, but returns the tqdm object """ Similar to :func:`get_tqdm_kwargs`, but returns the tqdm object
directly. """ directly. """
return tqdm(**get_tqdm_kwargs(**kwargs)) return tqdm(**get_tqdm_kwargs(**kwargs))
def building_rtfd():
"""
Returns:
bool: if tensorpack is being imported to generate docs now.
"""
return os.environ.get('READTHEDOCS') == 'True' \
or os.environ.get('TENSORPACK_DOC_BUILDING')
def log_deprecated(name="", text="", eos=""):
"""
Log deprecation warning.
Args:
name (str): name of the deprecated item.
text (str, optional): information about the deprecation.
eos (str, optional): end of service date such as "YYYY-MM-DD".
"""
assert name or text
if eos:
eos = "after " + datetime(*map(int, eos.split("-"))).strftime("%d %b")
if name:
if eos:
warn_msg = "%s will be deprecated %s. %s" % (name, eos, text)
else:
warn_msg = "%s was deprecated. %s" % (name, text)
else:
warn_msg = text
if eos:
warn_msg += " Legacy period ends %s" % eos
logger.warn("[Deprecated] " + warn_msg)
def deprecated(text="", eos=""):
"""
Args:
text, eos: same as :func:`log_deprecated`.
Returns:
a decorator which deprecates the function.
Example:
.. code-block:: python
@deprecated("Explanation of what to do instead.", "2017-11-4")
def foo(...):
pass
"""
def get_location():
import inspect
frame = inspect.currentframe()
if frame:
callstack = inspect.getouterframes(frame)[-1]
return '%s:%i' % (callstack[1], callstack[2])
else:
stack = inspect.stack(0)
entry = stack[2]
return '%s:%i' % (entry[1], entry[2])
def deprecated_inner(func):
@functools.wraps(func)
def new_func(*args, **kwargs):
name = "{} [{}]".format(func.__name__, get_location())
log_deprecated(name, text, eos)
return func(*args, **kwargs)
return new_func
return deprecated_inner
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment