Commit e8e8b014 authored by Yuxin Wu's avatar Yuxin Wu

address deprecation of tf.contrib

parent 0ffd5337
...@@ -48,7 +48,7 @@ A simple example of how it works: ...@@ -48,7 +48,7 @@ A simple example of how it works:
```python ```python
pred_config = PredictConfig( pred_config = PredictConfig(
model=YourModel(), model=YourModel(),
session_init=get_model_loader(model_path), session_init=SmartInit(model_path),
input_names=['input1', 'input2'], # tensor names in the graph, or name of the declared inputs input_names=['input1', 'input2'], # tensor names in the graph, or name of the declared inputs
output_names=['output1', 'output2']) # tensor names in the graph output_names=['output1', 'output2']) # tensor names in the graph
predictor = OfflinePredictor(pred_config) predictor = OfflinePredictor(pred_config)
...@@ -161,7 +161,7 @@ with TowerContext('', is_training=False): ...@@ -161,7 +161,7 @@ with TowerContext('', is_training=False):
### Step 2: load the checkpoint ### Step 2: load the checkpoint
You can just use `tf.train.Saver` for all the work. You can just use `tf.train.Saver` for all the work.
Alternatively, use tensorpack's `get_model_loader(path).init(tf.get_default_session())` Alternatively, use tensorpack's `SmartInit(path).init(tf.get_default_session())`
Now, you've already built a graph for inference, and the checkpoint is also loaded. Now, you've already built a graph for inference, and the checkpoint is also loaded.
You may now: You may now:
......
...@@ -7,7 +7,7 @@ import numpy as np ...@@ -7,7 +7,7 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorpack import get_default_sess_config, get_op_tensor_name from tensorpack import get_default_sess_config, get_op_tensor_name
from tensorpack.tfutils.sessinit import get_model_loader from tensorpack.tfutils.sessinit import SmartInit
from tensorpack.utils import logger from tensorpack.utils import logger
if __name__ == '__main__': if __name__ == '__main__':
...@@ -25,7 +25,7 @@ if __name__ == '__main__': ...@@ -25,7 +25,7 @@ if __name__ == '__main__':
tf.train.import_meta_graph(args.meta, clear_devices=True) tf.train.import_meta_graph(args.meta, clear_devices=True)
G = tf.get_default_graph() G = tf.get_default_graph()
with tf.Session(config=get_default_sess_config()) as sess: with tf.Session(config=get_default_sess_config()) as sess:
init = get_model_loader(args.model) init = SmartInit(args.model)
init.init(sess) init.init(sess)
feed = {} feed = {}
...@@ -52,16 +52,20 @@ if __name__ == '__main__': ...@@ -52,16 +52,20 @@ if __name__ == '__main__':
sess.run(fetches, feed_dict=feed, options=opt, run_metadata=meta) sess.run(fetches, feed_dict=feed, options=opt, run_metadata=meta)
if args.print_flops: if args.print_flops:
tf.contrib.tfprof.model_analyzer.print_model_analysis( tf.profiler.profile(
G, run_meta=meta, G,
tfprof_options=tf.contrib.tfprof.model_analyzer.FLOAT_OPS_OPTIONS) run_meta=meta,
cmd='op',
options=tf.profiler.ProfileOptionBuilder.float_operation())
if args.print_params: if args.print_params:
tf.contrib.tfprof.model_analyzer.print_model_analysis( tf.profiler.profile(
G, run_meta=meta, G,
tfprof_options=tf.contrib.tfprof.model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS) run_meta=meta,
options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
if args.print_timing: if args.print_timing:
tf.contrib.tfprof.model_analyzer.print_model_analysis( tf.profiler.profile(
G, run_meta=meta, G,
tfprof_options=tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY) run_meta=meta,
options=tf.profiler.ProfileOptionBuilder.time_and_memory())
...@@ -143,7 +143,7 @@ def allreduce_grads(all_grads, average): ...@@ -143,7 +143,7 @@ def allreduce_grads(all_grads, average):
""" """
if get_tf_version_tuple() <= (1, 12): if get_tf_version_tuple() <= (1, 12):
from tensorflow.contrib import nccl from tensorflow.contrib import nccl # deprecated
else: else:
from tensorflow.python.ops import nccl_ops as nccl from tensorflow.python.ops import nccl_ops as nccl
nr_tower = len(all_grads) nr_tower = len(all_grads)
......
...@@ -296,12 +296,12 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5, ...@@ -296,12 +296,12 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
if TF_version <= (1, 12): if TF_version <= (1, 12):
try: try:
from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so # deprecated
except Exception: except Exception:
pass pass
else: else:
_validate_and_load_nccl_so() _validate_and_load_nccl_so()
from tensorflow.contrib.nccl.ops import gen_nccl_ops from tensorflow.contrib.nccl.ops import gen_nccl_ops # deprecated
else: else:
from tensorflow.python.ops import gen_nccl_ops from tensorflow.python.ops import gen_nccl_ops
shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name) shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
......
...@@ -50,7 +50,7 @@ def Conv2D( ...@@ -50,7 +50,7 @@ def Conv2D(
""" """
if kernel_initializer is None: if kernel_initializer is None:
if get_tf_version_tuple() <= (1, 12): if get_tf_version_tuple() <= (1, 12):
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0) kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0) # deprecated
else: else:
kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal') kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')
dilation_rate = shape2d(dilation_rate) dilation_rate = shape2d(dilation_rate)
...@@ -175,7 +175,7 @@ def Conv2DTranspose( ...@@ -175,7 +175,7 @@ def Conv2DTranspose(
""" """
if kernel_initializer is None: if kernel_initializer is None:
if get_tf_version_tuple() <= (1, 12): if get_tf_version_tuple() <= (1, 12):
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0) kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0) # deprecated
else: else:
kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal') kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')
......
...@@ -48,7 +48,7 @@ def FullyConnected( ...@@ -48,7 +48,7 @@ def FullyConnected(
""" """
if kernel_initializer is None: if kernel_initializer is None:
if get_tf_version_tuple() <= (1, 12): if get_tf_version_tuple() <= (1, 12):
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0) kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0) # deprecated
else: else:
kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal') kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')
......
...@@ -22,8 +22,8 @@ def _log_once(msg): ...@@ -22,8 +22,8 @@ def _log_once(msg):
if get_tf_version_tuple() <= (1, 12): if get_tf_version_tuple() <= (1, 12):
l2_regularizer = tf.contrib.layers.l2_regularizer l2_regularizer = tf.contrib.layers.l2_regularizer # deprecated
l1_regularizer = tf.contrib.layers.l1_regularizer l1_regularizer = tf.contrib.layers.l1_regularizer # deprecated
else: else:
# oh these little dirty details # oh these little dirty details
l2_regularizer = lambda x: tf.keras.regularizers.l2(x * 0.5) # noqa l2_regularizer = lambda x: tf.keras.regularizers.l2(x * 0.5) # noqa
......
...@@ -32,7 +32,10 @@ def dependency_of_targets(targets, op): ...@@ -32,7 +32,10 @@ def dependency_of_targets(targets, op):
op = op.op op = op.op
assert isinstance(op, tf.Operation), op assert isinstance(op, tf.Operation), op
from tensorflow.contrib.graph_editor import get_backward_walk_ops try:
from tensorflow.contrib.graph_editor import get_backward_walk_ops # deprecated
except ImportError:
from tensorflow.python.ops.op_selector import get_backward_walk_ops
# alternative implementation can use graph_util.extract_sub_graph # alternative implementation can use graph_util.extract_sub_graph
dependent_ops = get_backward_walk_ops(targets, control_inputs=True) dependent_ops = get_backward_walk_ops(targets, control_inputs=True)
return op in dependent_ops return op in dependent_ops
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# File: sesscreate.py # File: sesscreate.py
from ..compat import tfv1 as tf, is_tfv2 from ..compat import tfv1 as tf
from ..utils import logger from ..utils import logger
from .common import get_default_sess_config from .common import get_default_sess_config
...@@ -70,8 +70,10 @@ class NewSessionCreator(tf.train.SessionCreator): ...@@ -70,8 +70,10 @@ class NewSessionCreator(tf.train.SessionCreator):
return False return False
def run(op): def run(op):
if not is_tfv2(): try:
from tensorflow.contrib.graph_editor import get_backward_walk_ops from tensorflow.contrib.graph_editor import get_backward_walk_ops # deprecated
except ImportError:
from tensorflow.python.ops.op_selector import get_backward_walk_ops
deps = get_backward_walk_ops(op, control_inputs=True) deps = get_backward_walk_ops(op, control_inputs=True)
for dep_op in deps: for dep_op in deps:
...@@ -80,6 +82,7 @@ class NewSessionCreator(tf.train.SessionCreator): ...@@ -80,6 +82,7 @@ class NewSessionCreator(tf.train.SessionCreator):
"Initializer '{}' depends on a blocking op '{}'. " "Initializer '{}' depends on a blocking op '{}'. "
"This initializer is likely to hang!".format( "This initializer is likely to hang!".format(
op.name, dep_op.name)) op.name, dep_op.name))
sess.run(op) sess.run(op)
run(tf.global_variables_initializer()) run(tf.global_variables_initializer())
......
...@@ -140,7 +140,7 @@ class SessionUpdate(object): ...@@ -140,7 +140,7 @@ class SessionUpdate(object):
def dump_session_params(path): def dump_session_params(path):
""" """
Dump value of all TRAINABLE + MODEL variables to a dict, and save as Dump value of all TRAINABLE + MODEL variables to a dict, and save as
npz format (loadable by :func:`sessinit.get_model_loader`). npz format (loadable by :func:`sessinit.SmartInit`).
Args: Args:
path(str): the file name to save the parameters. Must ends with npz. path(str): the file name to save the parameters. Must ends with npz.
......
...@@ -279,7 +279,10 @@ class SingleCostTrainer(TowerTrainer): ...@@ -279,7 +279,10 @@ class SingleCostTrainer(TowerTrainer):
if not self.XLA_COMPILE: if not self.XLA_COMPILE:
return compute_grad_from_inputs(*inputs) return compute_grad_from_inputs(*inputs)
else: else:
from tensorflow.contrib.compiler import xla try:
from tensorflow.contrib.compiler import xla # deprecated
except ImportError:
from tensorflow.python.compiler.xla import xla
def xla_func(): def xla_func():
grads = compute_grad_from_inputs(*inputs) grads = compute_grad_from_inputs(*inputs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment