Commit 3d1a30ff authored by Yuxin Wu

Move more core code to tf.compat.v1

parent 4057c531
......@@ -13,7 +13,7 @@ import tqdm
from tensorpack import ModelDesc
from tensorpack.dataflow import AugmentImageComponent, BatchData, MultiThreadMapData, PrefetchDataZMQ, dataset, imgaug
from tensorpack.input_source import QueueInput, StagingInput
-from tensorpack.models import regularize_cost
+from tensorpack.models import regularize_cost, l2_regularizer
from tensorpack.predict import FeedfreePredictor, PredictConfig
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils import logger
......@@ -339,7 +339,7 @@ class ImageNetModel(ModelDesc):
if self.weight_decay > 0:
wd_loss = regularize_cost(self.weight_decay_pattern,
-tf.contrib.layers.l2_regularizer(self.weight_decay),
+l2_regularizer(self.weight_decay),
name='l2_regularize_loss')
add_moving_summary(loss, wd_loss)
total_cost = tf.add_n([loss, wd_loss], name='cost')
......
......@@ -98,7 +98,7 @@ class Model(ModelDesc):
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
-tf.reduce_mean(tf.cast(tf.nn.in_top_k(logits, label, 1)), tf.float32, name='accuracy')
+tf.reduce_mean(tf.cast(tf.nn.in_top_k(logits, label, 1), tf.float32), name='accuracy')
wd_cost = tf.multiply(1e-5,
regularize_cost('fc.*/W', tf.nn.l2_loss),
......
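Note on the hunk above: the old line had a misplaced closing parenthesis, so tf.cast lost its dtype argument and tf.float32 was passed to tf.reduce_mean as its second positional argument (axis) instead. A minimal, self-contained sketch of the corrected pattern (illustrative, using the v1 API directly; not the repo's code):

    # TF1-style graph mode; placeholders require eager execution to be off
    import tensorflow.compat.v1 as tf1
    tf1.disable_eager_execution()

    logits = tf1.placeholder(tf1.float32, [None, 10])
    label = tf1.placeholder(tf1.int32, [None])
    correct = tf1.nn.in_top_k(logits, label, 1)  # bool tensor, shape [None]
    # cast inside reduce_mean's value argument, not as its axis:
    accuracy = tf1.reduce_mean(tf1.cast(correct, tf1.float32), name='accuracy')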
......@@ -10,6 +10,7 @@ import six
import tensorflow as tf
from six.moves import range, zip
+from ..compat import tfv1
from ..tfutils.common import get_tf_version_tuple
from ..tfutils.gradproc import ScaleGradient
from ..tfutils.tower import TrainTowerContext
......@@ -101,7 +102,7 @@ class DataParallelBuilder(GraphBuilder):
device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
usevs = use_vs[idx] if use_vs is not None else False
reuse = not usevs and idx > 0
-with tf.device(device), _maybe_reuse_vs(reuse), TrainTowerContext(
+with tfv1.device(device), _maybe_reuse_vs(reuse), TrainTowerContext(
tower_names[idx],
vs_name=tower_names[idx] if usevs else '',
index=idx, total=len(towers)):
......
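Many hunks in this commit replace bare tf.* calls with a tfv1 symbol imported from ..compat. As a rough, hypothetical sketch of what such a shim boils down to (tensorpack's actual tensorpack/compat module may differ in detail):

    import tensorflow as tf

    def is_tfv2():
        # TF2 reports a 2.x version string and ships the v1 API under tf.compat.v1
        return tf.__version__.startswith('2')

    if is_tfv2():
        tfv1 = tf.compat.v1
    else:
        tfv1 = tf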
......@@ -6,6 +6,7 @@ import operator
from contextlib import contextmanager
import tensorflow as tf
+from ..compat import tfv1
from ..tfutils.common import get_tf_version_tuple
from ..tfutils.scope_utils import cached_name_scope, under_name_scope
from ..tfutils.varreplace import custom_getter_scope
......@@ -82,7 +83,7 @@ class LeastLoadedDeviceSetter(object):
# from tensorflow.python.training.device_util import canonicalize
# from tensorflow.python.distribute.device_util import canonicalize
def canonicalize(name): # tensorflow/tensorflow#11484
-return tf.DeviceSpec.from_string(name).to_string()
+return tfv1.DeviceSpec.from_string(name).to_string()
if op.device:
return op.device
......
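The canonicalize helper above normalizes device strings so that different spellings of the same device compare equal. A quick illustration of the DeviceSpec round-trip (exact canonical form may vary across TF versions):

    import tensorflow.compat.v1 as tf1

    spec = tf1.DeviceSpec.from_string('/gpu:0')
    print(spec.to_string())  # typically '/device:GPU:0'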
......@@ -4,6 +4,7 @@
import tensorflow as tf
+from ..compat import tfv1
from .batch_norm import BatchNorm
from .common import VariableHolder, layer_register
......@@ -50,8 +51,8 @@ def PReLU(x, init=0.001, name='output'):
* ``alpha``: learnable slope.
"""
-init = tf.constant_initializer(init)
-alpha = tf.get_variable('alpha', [], initializer=init)
+init = tfv1.constant_initializer(init)
+alpha = tfv1.get_variable('alpha', [], initializer=init)
x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
ret = tf.multiply(x, 0.5, name=name)
......
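Aside on the PReLU hunk: PReLU is x for x > 0 and alpha * x otherwise, and the fused expression 0.5 * ((1 + alpha) * x + (1 - alpha) * |x|) is the same function. For x >= 0, |x| = x and it reduces to x; for x < 0, |x| = -x and it reduces to alpha * x. A quick numeric sanity check (not part of the diff):

    import numpy as np

    alpha = 0.25
    x = np.array([-2.0, -0.5, 0.0, 1.5])
    fused = 0.5 * ((1 + alpha) * x + (1 - alpha) * np.abs(x))
    piecewise = np.where(x > 0, x, alpha * x)
    assert np.allclose(fused, piecewise)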
......@@ -25,7 +25,8 @@ if get_tf_version_tuple() <= (1, 12):
l2_regularizer = tf.contrib.layers.l2_regularizer
l1_regularizer = tf.contrib.layers.l1_regularizer
else:
-l2_regularizer = tf.keras.regularizers.l2
+# oh these little dirty details
+l2_regularizer = lambda x: tf.keras.regularizers.l2(x * 0.5)  # noqa
l1_regularizer = tf.keras.regularizers.l1
......
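The "dirty detail" here: tf.contrib.layers.l2_regularizer(scale) computes scale * tf.nn.l2_loss(w), i.e. scale * sum(w**2) / 2, while tf.keras.regularizers.l2(l) computes l * sum(w**2) with no 1/2 factor. Calling the Keras version with x * 0.5 keeps the penalty numerically identical. A small check of the arithmetic (numpy stand-in, not the diff's code):

    import numpy as np

    w = np.array([3.0, 4.0])
    scale = 1e-4
    l2_loss = np.sum(w ** 2) / 2            # what tf.nn.l2_loss computes
    contrib_penalty = scale * l2_loss       # tf.contrib.layers.l2_regularizer(scale)
    keras_penalty = (scale * 0.5) * np.sum(w ** 2)  # tf.keras.regularizers.l2(scale * 0.5)
    assert np.isclose(contrib_penalty, keras_penalty)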
......@@ -8,6 +8,7 @@ import six
import tensorflow as tf
from six.moves import queue, range
+from ..compat import tfv1
from ..tfutils.model_utils import describe_trainable_vars
from ..utils import logger
from ..utils.concurrency import DIE, ShareSessionThread, StoppableThread
......@@ -162,7 +163,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
def start(self):
if self._need_default_sess:
-assert tf.get_default_session() is not None, \
+assert tfv1.get_default_session() is not None, \
"Not session is bind to predictors, " \
"MultiThreadAsyncPredictor.start() has to be called under a default session!"
for t in self.threads:
......
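The assertion above means a default session must be active when start() is called. A hypothetical usage sketch (predictor stands in for an already-constructed MultiThreadAsyncPredictor):

    import tensorflow.compat.v1 as tf1

    sess = tf1.Session()
    with sess.as_default():
        predictor.start()  # threads inherit this default session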
......@@ -3,7 +3,7 @@
import six
-import tensorflow as tf
+from ..compat import tfv1 as tf
from ..graph_builder import ModelDescBase
from ..tfutils import get_default_sess_config
......
......@@ -6,7 +6,9 @@ from collections import defaultdict
from contextlib import contextmanager
from functools import wraps
from inspect import getmembers, isfunction
+import tensorflow as tf
+from ..compat import is_tfv2
from ..utils import logger
from .tower import get_current_tower_context
......@@ -138,6 +140,8 @@ def enable_argscope_for_module(module, log_shape=True):
Args:
log_shape (bool): print input/output shapes of each function.
"""
+if is_tfv2() and module == tf.layers:
+module = tf.compat.v1.layers
for name, obj in getmembers(module):
if isfunction(obj):
setattr(module, name, enable_argscope_for_function(obj,
......
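The added branch redirects tf.layers to its tf.compat.v1.layers home under TF2 before the module's functions get wrapped. A usage sketch grounded in the signature shown above (import path assumed):

    import tensorflow as tf
    from tensorpack.tfutils.argscope import enable_argscope_for_module  # path assumed

    # wraps every function in the module so argscope-managed defaults apply;
    # under TF2 the hunk above resolves tf.layers to tf.compat.v1.layers first
    enable_argscope_for_module(tf.layers, log_shape=True)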
......@@ -12,6 +12,7 @@ from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from tensorflow.python.tools import optimize_for_inference_lib
+from ..compat import is_tfv2, tfv1
from ..input_source import PlaceholderInput
from ..tfutils.common import get_tensors_by_names, get_tf_version_tuple
from ..tfutils.tower import PredictTowerContext
......@@ -60,7 +61,7 @@ class ModelExporter(object):
self.config.session_init._setup_graph()
# we cannot use "self.config.session_creator.create_session()" here since it finalizes the graph
-sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+sess = tfv1.Session(config=tfv1.ConfigProto(allow_soft_placement=True))
self.config.session_init._run_init(sess)
dtypes = [n.dtype for n in input_tensors]
......@@ -88,7 +89,7 @@ class ModelExporter(object):
logger.info("Output graph written to {}.".format(filename))
def export_serving(self, filename,
-tags=[tf.saved_model.tag_constants.SERVING],
+tags=[tf.saved_model.SERVING if is_tfv2() else tf.saved_model.tag_constants.SERVING],
signature_name='prediction_pipeline'):
"""
Converts a checkpoint and graph to a servable for TensorFlow Serving.
......@@ -121,21 +122,22 @@ class ModelExporter(object):
self.config.tower_func(*input.get_input_tensors())
input_tensors = get_tensors_by_names(self.config.input_names)
-inputs_signatures = {t.name: tf.saved_model.utils.build_tensor_info(t) for t in input_tensors}
+saved_model = tfv1.saved_model.utils
+inputs_signatures = {t.name: saved_model.build_tensor_info(t) for t in input_tensors}
output_tensors = get_tensors_by_names(self.config.output_names)
-outputs_signatures = {t.name: tf.saved_model.utils.build_tensor_info(t) for t in output_tensors}
+outputs_signatures = {t.name: saved_model.build_tensor_info(t) for t in output_tensors}
self.config.session_init._setup_graph()
# we cannot use "self.config.session_creator.create_session()" here since it finalizes the graph
-sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+sess = tfv1.Session(config=tfv1.ConfigProto(allow_soft_placement=True))
self.config.session_init._run_init(sess)
-builder = tf.saved_model.builder.SavedModelBuilder(filename)
+builder = tfv1.saved_model.builder.SavedModelBuilder(filename)
-prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
+prediction_signature = tfv1.saved_model.signature_def_utils.build_signature_def(
inputs=inputs_signatures,
outputs=outputs_signatures,
-method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
+method_name=tfv1.saved_model.signature_constants.PREDICT_METHOD_NAME)
builder.add_meta_graph_and_variables(
sess, tags,
......
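For context, a hedged sketch of consuming the servable written by export_serving, using the v1 loader API (export path assumed; tag must match what was passed to add_meta_graph_and_variables):

    import tensorflow.compat.v1 as tf1

    with tf1.Session(graph=tf1.Graph()) as sess:
        tf1.saved_model.loader.load(
            sess, [tf1.saved_model.tag_constants.SERVING], '/path/to/export')  # path assumed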
......@@ -8,6 +8,7 @@ from abc import ABCMeta, abstractmethod
import six
import tensorflow as tf
+from ..compat import tfv1
from ..utils import logger
from .summary import add_moving_summary
from .symbolic_functions import print_stat, rms
......@@ -40,11 +41,11 @@ class GradientProcessor(object):
# reuse the old name_scope, if process() is called multiple times
if self._name_scope is None:
-with tf.name_scope(type(self).__name__) as scope:
+with tfv1.name_scope(type(self).__name__) as scope:
self._name_scope = scope
return self._process(grads)
else:
-with tf.name_scope(self._name_scope):
+with tfv1.name_scope(self._name_scope):
return self._process(grads)
@abstractmethod
......@@ -175,7 +176,7 @@ class SummaryGradient(MapGradient):
return grad
if name not in SummaryGradient._summaried_gradient:
SummaryGradient._summaried_gradient.add(name)
-tf.summary.histogram(name + '-grad', grad, collections=self._coll)
+tfv1.summary.histogram(name + '-grad', grad, collections=self._coll)
add_moving_summary(rms(grad, name=name + '/rms'))
return grad
......
......@@ -20,7 +20,7 @@ class ProxyOptimizer(tfv1.train.Optimizer):
A transparent proxy which delegates all methods of :class:`tf.train.Optimizer`
"""
def __init__(self, opt, name='ProxyOptimizer'):
-assert isinstance(opt, tf.train.Optimizer), opt
+assert isinstance(opt, tfv1.train.Optimizer), opt
super(ProxyOptimizer, self).__init__(False, name)
self._opt = opt
......
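The isinstance check now names the v1 optimizer base class explicitly. A small hypothetical usage sketch of wrapping a v1 optimizer (import path assumed):

    import tensorflow.compat.v1 as tf1
    from tensorpack.tfutils.optimizer import ProxyOptimizer  # path assumed

    base = tf1.train.GradientDescentOptimizer(learning_rate=0.1)
    opt = ProxyOptimizer(base)  # delegates compute_gradients/apply_gradients to base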
......@@ -4,8 +4,8 @@
import os
import numpy as np
import six
-import tensorflow as tf
+from ..compat import tfv1 as tf
from ..utils import logger
from .common import get_op_tensor_name
from .varmanip import SessionUpdate, get_checkpoint_path, get_savename_from_varname, is_training_name
......
......@@ -7,6 +7,7 @@ import pprint
import six
import tensorflow as tf
+from ..compat import tfv1
from ..utils import logger
from .common import get_op_tensor_name
......@@ -84,7 +85,7 @@ class SessionUpdate(object):
return None
if hasattr(value, 'dtype'):
-vartype = var.value().dtype
+vartype = var.dtype
if vartype != value.dtype:
msg = "Variable {} has dtype {} but was given a value of dtype {}.".format(name, vartype, value.dtype)
newtype = upcast(var.dtype.base_dtype, value.dtype)
......@@ -172,7 +173,7 @@ def get_checkpoint_path(model_path):
if os.path.basename(model_path) == model_path:
model_path = os.path.join('.', model_path) # avoid #4921 and #6142
if os.path.basename(model_path) == 'checkpoint':
-assert tf.gfile.Exists(model_path), model_path
+assert tfv1.gfile.Exists(model_path), model_path
model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
# to be consistent with either v1 or v2
......@@ -186,7 +187,7 @@ def get_checkpoint_path(model_path):
logger.info(
"Checkpoint path {} is auto-corrected to {}.".format(model_path, new_path))
model_path = new_path
-assert tf.gfile.Exists(model_path) or tf.gfile.Exists(model_path + '.index'), model_path
+assert tfv1.gfile.Exists(model_path) or tfv1.gfile.Exists(model_path + '.index'), model_path
return model_path
......@@ -200,7 +201,7 @@ def load_chkpt_vars(model_path):
dict: a name:value dict
"""
model_path = get_checkpoint_path(model_path)
-reader = tf.train.NewCheckpointReader(model_path)
+reader = tfv1.train.NewCheckpointReader(model_path)
var_names = reader.get_variable_to_shape_map().keys()
result = {}
for n in var_names:
......
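NewCheckpointReader reads variables straight from checkpoint files without building a graph. A standalone sketch mirroring load_chkpt_vars above (checkpoint path assumed):

    import tensorflow.compat.v1 as tf1

    reader = tf1.train.NewCheckpointReader('/path/to/model-10000')  # path assumed
    values = {name: reader.get_tensor(name)
              for name in reader.get_variable_to_shape_map()}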
......@@ -10,7 +10,7 @@ exclude = .git,
examples,
docs/conf.py
snippet,
-examples-old,
+examples_v2,
_test.py,
[isort]
......