Commit 23e4f928 authored by Yuxin Wu's avatar Yuxin Wu

hide some internal functions from import

parent ceb004a1
......@@ -8,7 +8,7 @@ import six
from six.moves import zip
from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils import get_op_tensor_name
from ..tfutils.common import get_op_tensor_name
__all__ = ['ScalarStats', 'Inferencer',
'ClassificationError', 'BinaryClassificationStats']
......
......@@ -11,7 +11,7 @@ import os
from .base import Callback
from ..utils import logger
from ..tfutils import get_op_tensor_name
from ..tfutils.common import get_op_tensor_name
__all__ = ['HyperParam', 'GraphVarParam', 'ObjAttrParam',
'HyperParamSetter', 'HumanHyperParamSetter',
......
......@@ -7,7 +7,7 @@ import numpy as np
from .base import Callback
from ..utils import logger
from ..tfutils import get_op_tensor_name
from ..tfutils.common import get_op_tensor_name
__all__ = ['SendStat', 'DumpParamAsImage', 'InjectShell']
......
......@@ -15,7 +15,7 @@ from six.moves import range, zip
from .input_source_base import InputSource
from ..dataflow import DataFlow, RepeatedData
from ..tfutils.summary import add_moving_summary
from ..tfutils import get_op_tensor_name
from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from ..utils.concurrency import ShareSessionThread
......
......@@ -10,7 +10,7 @@ import tensorflow as tf
from ..utils import logger
from ..utils.concurrency import DIE, StoppableThread, ShareSessionThread
from ..tfutils.model_utils import describe_model
from ..tfutils.model_utils import describe_trainable_vars
from .base import OnlinePredictor, OfflinePredictor, AsyncPredictorBase
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
......@@ -41,7 +41,7 @@ class MultiProcessPredictWorker(multiprocessing.Process):
self.predictor = OfflinePredictor(self.config)
if self.idx == 0:
with self.predictor.graph.as_default():
describe_model()
describe_trainable_vars()
class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
......
......@@ -10,10 +10,10 @@ from ..utils.argtools import graph_memoized
__all__ = ['get_default_sess_config',
'get_global_step_value',
'get_global_step_var',
'get_op_tensor_name',
'get_tensors_by_names',
'get_op_or_tensor_by_name',
'get_tf_version_number',
# 'get_op_tensor_name',
# 'get_tensors_by_names',
# 'get_op_or_tensor_by_name',
# 'get_tf_version_number',
]
......
......@@ -8,10 +8,10 @@ from tabulate import tabulate
from ..utils import logger
__all__ = ['describe_model', 'get_shape_str']
__all__ = []
def describe_model():
def describe_trainable_vars():
"""
Print a description of the current model parameters.
Skip variables starting with "tower".
......
......@@ -12,7 +12,7 @@ if six.PY2:
else:
import functools
__all__ = ['get_name_scope_name', 'auto_reuse_variable_scope']
__all__ = ['auto_reuse_variable_scope']
@deprecated("Use tf.get_default_graph().get_name_scope() (available since 1.2.1).")
......
......@@ -16,8 +16,8 @@ from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .tower import get_current_tower_context
from .symbolic_functions import rms
__all__ = ['create_scalar_summary', 'add_param_summary', 'add_activation_summary',
'add_moving_summary']
__all__ = ['create_scalar_summary', 'add_param_summary',
'add_activation_summary', 'add_moving_summary']
def create_scalar_summary(name, v):
......
......@@ -6,6 +6,8 @@ import tensorflow as tf
from contextlib import contextmanager
import numpy as np
__all__ = []
# this function exists for backwards-compatibilty
def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
......
......@@ -13,7 +13,7 @@ from ..utils import logger
from .common import get_op_tensor_name
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
'get_savename_from_varname', 'is_training_name',
# 'get_savename_from_varname', 'is_training_name',
'get_checkpoint_path']
......
......@@ -16,7 +16,7 @@ from ..utils import logger
from ..callbacks import Callback, Callbacks, MaintainStepCounter
from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils import get_global_step_value
from ..tfutils.model_utils import describe_model
from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.sessinit import JustCurrentSession
......@@ -127,7 +127,7 @@ class Trainer(object):
self.monitors = Monitors(self.monitors)
self.register_callback(self.monitors)
describe_model()
describe_trainable_vars()
# some final operations that might modify the graph
logger.info("Setup callbacks graph ...")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment