Commit a47c9980 authored by Yuxin Wu's avatar Yuxin Wu

hide some internal functions to develop.py.

parent 3e61aacd
tensorpack.tfutils package
==========================
tensorpack.tfutils.collection module
------------------------------------
.. automodule:: tensorpack.tfutils.collection
:members:
:undoc-members:
:show-inheritance:
tensorpack.tfutils.distributions module
---------------------------------------
.. automodule:: tensorpack.tfutils.distributions
:members:
:undoc-members:
:show-inheritance:
tensorpack.tfutils.gradproc module
------------------------------------
.. automodule:: tensorpack.tfutils.gradproc
:members:
:undoc-members:
:show-inheritance:
tensorpack.tfutils.modelutils module
------------------------------------
......@@ -9,6 +33,22 @@ tensorpack.tfutils.modelutils module
:undoc-members:
:show-inheritance:
tensorpack.tfutils.optimizer module
------------------------------------
.. automodule:: tensorpack.tfutils.optimizer
:members:
:undoc-members:
:show-inheritance:
tensorpack.tfutils.sesscreate module
------------------------------------
.. automodule:: tensorpack.tfutils.sesscreate
:members:
:undoc-members:
:show-inheritance:
tensorpack.tfutils.summary module
---------------------------------
......
......@@ -17,13 +17,6 @@ tensorpack.utils.concurrency module
:undoc-members:
:show-inheritance:
tensorpack.utils.debug module
-----------------------------
.. automodule:: tensorpack.utils.debug
:members:
:undoc-members:
:show-inheritance:
tensorpack.utils.discretize module
----------------------------------
......
......@@ -9,7 +9,7 @@ from collections import defaultdict
import six
from ..utils import get_rng
__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
__all__ = ['RLEnvironment', 'ProxyPlayer',
'DiscreteActionSpace']
......
......@@ -79,7 +79,7 @@ try:
# https://github.com/openai/gym/pull/199
# not sure does it cause other problems
except ImportError:
from ..utils.dependency import create_dummy_class
from ..utils.develop import create_dummy_class
GymEnv = create_dummy_class('GymEnv', 'gym') # noqa
......
......@@ -91,7 +91,7 @@ class BSDS500(RNGDataFlow):
try:
from scipy.io import loadmat
except ImportError:
from ...utils.dependency import create_dummy_class
from ...utils.develop import create_dummy_class
BSDS500 = create_dummy_class('BSDS500', 'scipy.io') # noqa
if __name__ == '__main__':
......
......@@ -72,7 +72,7 @@ class SVHNDigit(RNGDataFlow):
try:
import scipy.io
except ImportError:
from ...utils.dependency import create_dummy_class
from ...utils.develop import create_dummy_class
SVHNDigit = create_dummy_class('SVHNDigit', 'scipy.io') # noqa
if __name__ == '__main__':
......
......@@ -82,7 +82,7 @@ def dump_dataflow_to_lmdb(ds, lmdb_path):
try:
import lmdb
except ImportError:
from ..utils.dependency import create_dummy_func
from ..utils.develop import create_dummy_func
dump_dataflow_to_lmdb = create_dummy_func('dump_dataflow_to_lmdb', 'lmdb') # noqa
......
......@@ -181,8 +181,8 @@ def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
a :class:`LMDBDataDecoder` instance.
Example:
.. code-block:: python
.. code-block:: none
ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')
"""
......@@ -226,7 +226,7 @@ class SVMLightData(RNGDataFlow):
yield [self.X[id, :], self.y[id]]
from ..utils.dependency import create_dummy_class # noqa
from ..utils.develop import create_dummy_class # noqa
try:
import h5py
except ImportError:
......
......@@ -10,7 +10,8 @@ import copy
from ..tfutils.argscope import get_arg_scope
from ..tfutils.modelutils import get_shape_str
from ..tfutils.summary import add_activation_summary
from ..utils import logger, building_rtfd
from ..utils import logger
from ..utils.develop import building_rtfd
# make sure each layer is only logged once
_LAYER_LOGGED = set()
......
......@@ -8,7 +8,9 @@ import tensorflow as tf
import pickle
import six
from ..utils import logger, INPUTS_KEY, deprecated, log_deprecated
from ..utils import logger
from ..utils.naming import INPUTS_KEY
from ..utils.develop import deprecated, log_deprecated
from ..utils.argtools import memoized
from ..tfutils.modelutils import apply_slim_collections
......
......@@ -7,7 +7,8 @@ from abc import abstractmethod, ABCMeta
import tensorflow as tf
import six
from ..utils import logger, deprecated
from ..utils import logger
from ..utils.develop import deprecated
from ..utils.argtools import memoized
from ..utils.naming import SUMMARY_BACKUP_KEYS
from ..tfutils import get_tensors_by_names, TowerContext
......@@ -60,8 +61,10 @@ class PredictorBase(object):
@abstractmethod
def _do_call(self, dp):
"""
:param dp: input datapoint. must have the same length as input_names
:return: output as defined by the config
Args:
dp: input datapoint. must have the same length as input_names
Returns:
output as defined by the config
"""
......
......@@ -8,7 +8,8 @@ import six
from six.moves import queue, range
import tensorflow as tf
from ..utils import logger, deprecated
from ..utils import logger
from ..utils.develop import deprecated
from ..utils.concurrency import DIE, StoppableThread, ShareSessionThread
from ..tfutils.modelutils import describe_model
from .base import OnlinePredictor, OfflinePredictor, AsyncPredictorBase
......@@ -185,5 +186,5 @@ try:
else:
from concurrent.futures import Future
except ImportError:
from ..utils.dependency import create_dummy_class
from ..utils.develop import create_dummy_class
MultiThreadAsyncPredictor = create_dummy_class('MultiThreadAsyncPredictor', 'tornado.concurrent') # noqa
......@@ -6,7 +6,7 @@ import tensorflow as tf
import six
from ..models import ModelDesc
from ..utils import log_deprecated
from ..utils.develop import log_deprecated
from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession
from ..tfutils.sesscreate import NewSession
......
......@@ -76,9 +76,10 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
class DataParallelOfflinePredictor(OnlinePredictor):
""" A data-parallel predictor.
Its input is: [input[0] in tower[0], input[1] in tower[0], ...,
input[0] in tower[1], input[1] in tower[1], ...]
"""
A data-parallel predictor.
Note that it doesn't split/concat inputs/outputs automatically.
Its input is: ``[input[0] in tower[0], input[1] in tower[0], ..., input[0] in tower[1], input[1] in tower[1], ...]``
And same for the output.
"""
......
......@@ -6,7 +6,8 @@ import six
import tensorflow as tf
import re
from ..utils import log_deprecated, logger
from ..utils import logger
from ..utils.develop import log_deprecated
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .tower import get_current_tower_context
from .symbolic_functions import rms
......
......@@ -12,7 +12,8 @@ from six.moves import range
import tensorflow as tf
from .predict import PredictorFactory
from .config import TrainConfig
from ..utils import logger, deprecated, log_deprecated
from ..utils import logger
from ..utils.develop import deprecated, log_deprecated
from ..callbacks import StatHolder
from ..tfutils import get_global_step_value
from ..tfutils.modelutils import describe_model
......
......@@ -10,7 +10,8 @@ from ..callbacks import (
MaintainStepCounter)
from ..dataflow.base import DataFlow
from ..models import ModelDesc
from ..utils import logger, log_deprecated
from ..utils import logger
from ..utils.develop import log_deprecated
from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit)
from ..tfutils.optimizer import apply_grad_processors
......@@ -130,6 +131,8 @@ class TrainConfig(object):
self.predict_tower = predict_tower
if isinstance(self.predict_tower, int):
self.predict_tower = [self.predict_tower]
assert len(set(self.predict_tower)) == len(self.predict_tower), \
"Cannot have duplicated predict_tower!"
if 'optimizer' in kwargs:
log_deprecated("TrainConfig(optimizer=...)",
......
......@@ -5,7 +5,7 @@
import tensorflow as tf
from ..utils import log_deprecated
from ..utils.develop import log_deprecated
from ..tfutils.tower import TowerContext, get_current_tower_context
from .input_data import QueueInput, FeedfreeInput
......
......@@ -8,7 +8,8 @@ import itertools
import re
from six.moves import zip, range
from ..utils import logger, log_deprecated
from ..utils import logger
from ..utils.develop import log_deprecated
from ..utils.naming import SUMMARY_BACKUP_KEYS
from ..utils.concurrency import LoopThread
from ..tfutils.tower import TowerContext
......
......@@ -33,7 +33,7 @@ class PredictorFactory(object):
Returns:
an online predictor (which has to be used under a default session)
"""
tower = self.towers[tower] # TODO is it good?
tower = self.towers[tower]
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
# just ensure the tower exists. won't rebuild
self._tower_builder.build(tower)
......
......@@ -94,8 +94,8 @@ def shape4d(a, data_format='NHWC'):
a: a int or tuple/list of length 2
Returns:
list: of length 4. if ``a`` is a int, return ``[1, a, a, 1]`` or ``[1,
1, a, a]`` depending on data_format.
list: of length 4. if ``a`` is a int, return ``[1, a, a, 1]``
or ``[1, 1, a, a]`` depending on data_format.
"""
s2d = shape2d(a)
if data_format == 'NHWC':
......
......@@ -5,7 +5,6 @@
import sys
__all__ = ['enable_call_trace']
def enable_call_trace():
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: dependency.py
# File: develop.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
""" Utilities to handle dependency """
__all__ = ['create_dummy_func', 'create_dummy_class']
""" Utilities for developers only.
These are not visible to users (not automatically imported). And should not
appeared in docs."""
import os
import functools
from datetime import datetime
from . import logger
def create_dummy_class(klass, dependency):
......@@ -40,3 +44,73 @@ def create_dummy_func(func, dependency):
def _dummy(*args, **kwargs):
raise ImportError("Cannot import '{}', therefore '{}' is not available".format(dependency, func))
return _dummy
def building_rtfd():
"""
Returns:
bool: if tensorpack is being imported to generate docs now.
"""
return os.environ.get('READTHEDOCS') == 'True' \
or os.environ.get('TENSORPACK_DOC_BUILDING')
def log_deprecated(name="", text="", eos=""):
"""
Log deprecation warning.
Args:
name (str): name of the deprecated item.
text (str, optional): information about the deprecation.
eos (str, optional): end of service date such as "YYYY-MM-DD".
"""
assert name or text
if eos:
eos = "after " + datetime(*map(int, eos.split("-"))).strftime("%d %b")
if name:
if eos:
warn_msg = "%s will be deprecated %s. %s" % (name, eos, text)
else:
warn_msg = "%s was deprecated. %s" % (name, text)
else:
warn_msg = text
if eos:
warn_msg += " Legacy period ends %s" % eos
logger.warn("[Deprecated] " + warn_msg)
def deprecated(text="", eos=""):
"""
Args:
text, eos: same as :func:`log_deprecated`.
Returns:
a decorator which deprecates the function.
Example:
.. code-block:: python
@deprecated("Explanation of what to do instead.", "2017-11-4")
def foo(...):
pass
"""
def get_location():
import inspect
frame = inspect.currentframe()
if frame:
callstack = inspect.getouterframes(frame)[-1]
return '%s:%i' % (callstack[1], callstack[2])
else:
stack = inspect.stack(0)
entry = stack[2]
return '%s:%i' % (entry[1], entry[2])
def deprecated_inner(func):
@functools.wraps(func)
def new_func(*args, **kwargs):
name = "{} [{}]".format(func.__name__, get_location())
log_deprecated(name, text, eos)
return func(*args, **kwargs)
return new_func
return deprecated_inner
......@@ -9,8 +9,6 @@ import inspect
from datetime import datetime
from tqdm import tqdm
import numpy as np
import functools
from . import logger
__all__ = ['change_env',
......@@ -18,9 +16,6 @@ __all__ = ['change_env',
'get_tqdm_kwargs',
'get_tqdm',
'execute_only_once',
'building_rtfd',
'log_deprecated',
'deprecated'
]
......@@ -110,73 +105,3 @@ def get_tqdm(**kwargs):
""" Similar to :func:`get_tqdm_kwargs`, but returns the tqdm object
directly. """
return tqdm(**get_tqdm_kwargs(**kwargs))
def building_rtfd():
"""
Returns:
bool: if tensorpack is being imported to generate docs now.
"""
return os.environ.get('READTHEDOCS') == 'True' \
or os.environ.get('TENSORPACK_DOC_BUILDING')
def log_deprecated(name="", text="", eos=""):
"""
Log deprecation warning.
Args:
name (str): name of the deprecated item.
text (str, optional): information about the deprecation.
eos (str, optional): end of service date such as "YYYY-MM-DD".
"""
assert name or text
if eos:
eos = "after " + datetime(*map(int, eos.split("-"))).strftime("%d %b")
if name:
if eos:
warn_msg = "%s will be deprecated %s. %s" % (name, eos, text)
else:
warn_msg = "%s was deprecated. %s" % (name, text)
else:
warn_msg = text
if eos:
warn_msg += " Legacy period ends %s" % eos
logger.warn("[Deprecated] " + warn_msg)
def deprecated(text="", eos=""):
"""
Args:
text, eos: same as :func:`log_deprecated`.
Returns:
a decorator which deprecates the function.
Example:
.. code-block:: python
@deprecated("Explanation of what to do instead.", "2017-11-4")
def foo(...):
pass
"""
def get_location():
import inspect
frame = inspect.currentframe()
if frame:
callstack = inspect.getouterframes(frame)[-1]
return '%s:%i' % (callstack[1], callstack[2])
else:
stack = inspect.stack(0)
entry = stack[2]
return '%s:%i' % (entry[1], entry[2])
def deprecated_inner(func):
@functools.wraps(func)
def new_func(*args, **kwargs):
name = "{} [{}]".format(func.__name__, get_location())
log_deprecated(name, text, eos)
return func(*args, **kwargs)
return new_func
return deprecated_inner
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment