Commit e8e8b014 authored by Yuxin Wu's avatar Yuxin Wu

address deprecation of tf.contrib

parent 0ffd5337
......@@ -48,7 +48,7 @@ A simple example of how it works:
```python
pred_config = PredictConfig(
model=YourModel(),
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
input_names=['input1', 'input2'], # tensor names in the graph, or name of the declared inputs
output_names=['output1', 'output2']) # tensor names in the graph
predictor = OfflinePredictor(pred_config)
......@@ -161,7 +161,7 @@ with TowerContext('', is_training=False):
### Step 2: load the checkpoint
You can just use `tf.train.Saver` for all the work.
Alternatively, use tensorpack's `get_model_loader(path).init(tf.get_default_session())`
Alternatively, use tensorpack's `SmartInit(path).init(tf.get_default_session())`
Now, you've already built a graph for inference, and the checkpoint is also loaded.
You may now:
......
......@@ -7,7 +7,7 @@ import numpy as np
import tensorflow as tf
from tensorpack import get_default_sess_config, get_op_tensor_name
from tensorpack.tfutils.sessinit import get_model_loader
from tensorpack.tfutils.sessinit import SmartInit
from tensorpack.utils import logger
if __name__ == '__main__':
......@@ -25,7 +25,7 @@ if __name__ == '__main__':
tf.train.import_meta_graph(args.meta, clear_devices=True)
G = tf.get_default_graph()
with tf.Session(config=get_default_sess_config()) as sess:
init = get_model_loader(args.model)
init = SmartInit(args.model)
init.init(sess)
feed = {}
......@@ -52,16 +52,20 @@ if __name__ == '__main__':
sess.run(fetches, feed_dict=feed, options=opt, run_metadata=meta)
if args.print_flops:
tf.contrib.tfprof.model_analyzer.print_model_analysis(
G, run_meta=meta,
tfprof_options=tf.contrib.tfprof.model_analyzer.FLOAT_OPS_OPTIONS)
tf.profiler.profile(
G,
run_meta=meta,
cmd='op',
options=tf.profiler.ProfileOptionBuilder.float_operation())
if args.print_params:
tf.contrib.tfprof.model_analyzer.print_model_analysis(
G, run_meta=meta,
tfprof_options=tf.contrib.tfprof.model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
tf.profiler.profile(
G,
run_meta=meta,
options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
if args.print_timing:
tf.contrib.tfprof.model_analyzer.print_model_analysis(
G, run_meta=meta,
tfprof_options=tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY)
tf.profiler.profile(
G,
run_meta=meta,
options=tf.profiler.ProfileOptionBuilder.time_and_memory())
......@@ -143,7 +143,7 @@ def allreduce_grads(all_grads, average):
"""
if get_tf_version_tuple() <= (1, 12):
from tensorflow.contrib import nccl
from tensorflow.contrib import nccl # deprecated
else:
from tensorflow.python.ops import nccl_ops as nccl
nr_tower = len(all_grads)
......
......@@ -296,12 +296,12 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
if TF_version <= (1, 12):
try:
from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so
from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so # deprecated
except Exception:
pass
else:
_validate_and_load_nccl_so()
from tensorflow.contrib.nccl.ops import gen_nccl_ops
from tensorflow.contrib.nccl.ops import gen_nccl_ops # deprecated
else:
from tensorflow.python.ops import gen_nccl_ops
shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
......
......@@ -50,7 +50,7 @@ def Conv2D(
"""
if kernel_initializer is None:
if get_tf_version_tuple() <= (1, 12):
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0)
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0) # deprecated
else:
kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')
dilation_rate = shape2d(dilation_rate)
......@@ -175,7 +175,7 @@ def Conv2DTranspose(
"""
if kernel_initializer is None:
if get_tf_version_tuple() <= (1, 12):
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0)
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0) # deprecated
else:
kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')
......
......@@ -48,7 +48,7 @@ def FullyConnected(
"""
if kernel_initializer is None:
if get_tf_version_tuple() <= (1, 12):
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0)
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0) # deprecated
else:
kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')
......
......@@ -22,8 +22,8 @@ def _log_once(msg):
if get_tf_version_tuple() <= (1, 12):
l2_regularizer = tf.contrib.layers.l2_regularizer
l1_regularizer = tf.contrib.layers.l1_regularizer
l2_regularizer = tf.contrib.layers.l2_regularizer # deprecated
l1_regularizer = tf.contrib.layers.l1_regularizer # deprecated
else:
# oh these little dirty details
l2_regularizer = lambda x: tf.keras.regularizers.l2(x * 0.5) # noqa
......
......@@ -32,7 +32,10 @@ def dependency_of_targets(targets, op):
op = op.op
assert isinstance(op, tf.Operation), op
from tensorflow.contrib.graph_editor import get_backward_walk_ops
try:
from tensorflow.contrib.graph_editor import get_backward_walk_ops # deprecated
except ImportError:
from tensorflow.python.ops.op_selector import get_backward_walk_ops
# alternative implementation can use graph_util.extract_sub_graph
dependent_ops = get_backward_walk_ops(targets, control_inputs=True)
return op in dependent_ops
......
......@@ -2,7 +2,7 @@
# File: sesscreate.py
from ..compat import tfv1 as tf, is_tfv2
from ..compat import tfv1 as tf
from ..utils import logger
from .common import get_default_sess_config
......@@ -70,16 +70,19 @@ class NewSessionCreator(tf.train.SessionCreator):
return False
def run(op):
if not is_tfv2():
from tensorflow.contrib.graph_editor import get_backward_walk_ops
deps = get_backward_walk_ops(op, control_inputs=True)
for dep_op in deps:
if blocking_op(dep_op):
logger.warn(
"Initializer '{}' depends on a blocking op '{}'. "
"This initializer is likely to hang!".format(
op.name, dep_op.name))
try:
from tensorflow.contrib.graph_editor import get_backward_walk_ops # deprecated
except ImportError:
from tensorflow.python.ops.op_selector import get_backward_walk_ops
deps = get_backward_walk_ops(op, control_inputs=True)
for dep_op in deps:
if blocking_op(dep_op):
logger.warn(
"Initializer '{}' depends on a blocking op '{}'. "
"This initializer is likely to hang!".format(
op.name, dep_op.name))
sess.run(op)
run(tf.global_variables_initializer())
......
......@@ -140,7 +140,7 @@ class SessionUpdate(object):
def dump_session_params(path):
"""
Dump value of all TRAINABLE + MODEL variables to a dict, and save as
npz format (loadable by :func:`sessinit.get_model_loader`).
npz format (loadable by :func:`sessinit.SmartInit`).
Args:
path(str): the file name to save the parameters. Must ends with npz.
......
......@@ -279,7 +279,10 @@ class SingleCostTrainer(TowerTrainer):
if not self.XLA_COMPILE:
return compute_grad_from_inputs(*inputs)
else:
from tensorflow.contrib.compiler import xla
try:
from tensorflow.contrib.compiler import xla # deprecated
except ImportError:
from tensorflow.python.compiler.xla import xla
def xla_func():
grads = compute_grad_from_inputs(*inputs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment