Commit 23e4f928 authored by Yuxin Wu's avatar Yuxin Wu

hide some internal functions from import

parent ceb004a1
...@@ -8,7 +8,7 @@ import six ...@@ -8,7 +8,7 @@ import six
from six.moves import zip from six.moves import zip
from ..utils.stats import RatioCounter, BinaryStatistics from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
__all__ = ['ScalarStats', 'Inferencer', __all__ = ['ScalarStats', 'Inferencer',
'ClassificationError', 'BinaryClassificationStats'] 'ClassificationError', 'BinaryClassificationStats']
......
...@@ -11,7 +11,7 @@ import os ...@@ -11,7 +11,7 @@ import os
from .base import Callback from .base import Callback
from ..utils import logger from ..utils import logger
from ..tfutils import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
__all__ = ['HyperParam', 'GraphVarParam', 'ObjAttrParam', __all__ = ['HyperParam', 'GraphVarParam', 'ObjAttrParam',
'HyperParamSetter', 'HumanHyperParamSetter', 'HyperParamSetter', 'HumanHyperParamSetter',
......
...@@ -7,7 +7,7 @@ import numpy as np ...@@ -7,7 +7,7 @@ import numpy as np
from .base import Callback from .base import Callback
from ..utils import logger from ..utils import logger
from ..tfutils import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
__all__ = ['SendStat', 'DumpParamAsImage', 'InjectShell'] __all__ = ['SendStat', 'DumpParamAsImage', 'InjectShell']
......
...@@ -15,7 +15,7 @@ from six.moves import range, zip ...@@ -15,7 +15,7 @@ from six.moves import range, zip
from .input_source_base import InputSource from .input_source_base import InputSource
from ..dataflow import DataFlow, RepeatedData from ..dataflow import DataFlow, RepeatedData
from ..tfutils.summary import add_moving_summary from ..tfutils.summary import add_moving_summary
from ..tfutils import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
from ..utils import logger from ..utils import logger
from ..utils.concurrency import ShareSessionThread from ..utils.concurrency import ShareSessionThread
......
...@@ -10,7 +10,7 @@ import tensorflow as tf ...@@ -10,7 +10,7 @@ import tensorflow as tf
from ..utils import logger from ..utils import logger
from ..utils.concurrency import DIE, StoppableThread, ShareSessionThread from ..utils.concurrency import DIE, StoppableThread, ShareSessionThread
from ..tfutils.model_utils import describe_model from ..tfutils.model_utils import describe_trainable_vars
from .base import OnlinePredictor, OfflinePredictor, AsyncPredictorBase from .base import OnlinePredictor, OfflinePredictor, AsyncPredictorBase
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker', __all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
...@@ -41,7 +41,7 @@ class MultiProcessPredictWorker(multiprocessing.Process): ...@@ -41,7 +41,7 @@ class MultiProcessPredictWorker(multiprocessing.Process):
self.predictor = OfflinePredictor(self.config) self.predictor = OfflinePredictor(self.config)
if self.idx == 0: if self.idx == 0:
with self.predictor.graph.as_default(): with self.predictor.graph.as_default():
describe_model() describe_trainable_vars()
class MultiProcessQueuePredictWorker(MultiProcessPredictWorker): class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
......
...@@ -10,10 +10,10 @@ from ..utils.argtools import graph_memoized ...@@ -10,10 +10,10 @@ from ..utils.argtools import graph_memoized
__all__ = ['get_default_sess_config', __all__ = ['get_default_sess_config',
'get_global_step_value', 'get_global_step_value',
'get_global_step_var', 'get_global_step_var',
'get_op_tensor_name', # 'get_op_tensor_name',
'get_tensors_by_names', # 'get_tensors_by_names',
'get_op_or_tensor_by_name', # 'get_op_or_tensor_by_name',
'get_tf_version_number', # 'get_tf_version_number',
] ]
......
...@@ -8,10 +8,10 @@ from tabulate import tabulate ...@@ -8,10 +8,10 @@ from tabulate import tabulate
from ..utils import logger from ..utils import logger
__all__ = ['describe_model', 'get_shape_str'] __all__ = []
def describe_model(): def describe_trainable_vars():
""" """
Print a description of the current model parameters. Print a description of the current model parameters.
Skip variables starting with "tower". Skip variables starting with "tower".
......
...@@ -12,7 +12,7 @@ if six.PY2: ...@@ -12,7 +12,7 @@ if six.PY2:
else: else:
import functools import functools
__all__ = ['get_name_scope_name', 'auto_reuse_variable_scope'] __all__ = ['auto_reuse_variable_scope']
@deprecated("Use tf.get_default_graph().get_name_scope() (available since 1.2.1).") @deprecated("Use tf.get_default_graph().get_name_scope() (available since 1.2.1).")
......
...@@ -16,8 +16,8 @@ from ..utils.naming import MOVING_SUMMARY_OPS_KEY ...@@ -16,8 +16,8 @@ from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .tower import get_current_tower_context from .tower import get_current_tower_context
from .symbolic_functions import rms from .symbolic_functions import rms
__all__ = ['create_scalar_summary', 'add_param_summary', 'add_activation_summary', __all__ = ['create_scalar_summary', 'add_param_summary',
'add_moving_summary'] 'add_activation_summary', 'add_moving_summary']
def create_scalar_summary(name, v): def create_scalar_summary(name, v):
......
...@@ -6,6 +6,8 @@ import tensorflow as tf ...@@ -6,6 +6,8 @@ import tensorflow as tf
from contextlib import contextmanager from contextlib import contextmanager
import numpy as np import numpy as np
__all__ = []
# this function exists for backwards-compatibilty # this function exists for backwards-compatibilty
def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'): def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
......
...@@ -13,7 +13,7 @@ from ..utils import logger ...@@ -13,7 +13,7 @@ from ..utils import logger
from .common import get_op_tensor_name from .common import get_op_tensor_name
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars', __all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
'get_savename_from_varname', 'is_training_name', # 'get_savename_from_varname', 'is_training_name',
'get_checkpoint_path'] 'get_checkpoint_path']
......
...@@ -16,7 +16,7 @@ from ..utils import logger ...@@ -16,7 +16,7 @@ from ..utils import logger
from ..callbacks import Callback, Callbacks, MaintainStepCounter from ..callbacks import Callback, Callbacks, MaintainStepCounter
from ..callbacks.monitor import Monitors, TrainingMonitor from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils import get_global_step_value from ..tfutils import get_global_step_value
from ..tfutils.model_utils import describe_model from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sesscreate import ReuseSessionCreator from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.sessinit import JustCurrentSession from ..tfutils.sessinit import JustCurrentSession
...@@ -127,7 +127,7 @@ class Trainer(object): ...@@ -127,7 +127,7 @@ class Trainer(object):
self.monitors = Monitors(self.monitors) self.monitors = Monitors(self.monitors)
self.register_callback(self.monitors) self.register_callback(self.monitors)
describe_model() describe_trainable_vars()
# some final operations that might modify the graph # some final operations that might modify the graph
logger.info("Setup callbacks graph ...") logger.info("Setup callbacks graph ...")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment