Commit 3d1a30ff authored by Yuxin Wu

Move more core code to tf.compat.v1

parent 4057c531
@@ -13,7 +13,7 @@ import tqdm
 from tensorpack import ModelDesc
 from tensorpack.dataflow import AugmentImageComponent, BatchData, MultiThreadMapData, PrefetchDataZMQ, dataset, imgaug
 from tensorpack.input_source import QueueInput, StagingInput
-from tensorpack.models import regularize_cost
+from tensorpack.models import regularize_cost, l2_regularizer
 from tensorpack.predict import FeedfreePredictor, PredictConfig
 from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.utils import logger
@@ -339,7 +339,7 @@ class ImageNetModel(ModelDesc):
         if self.weight_decay > 0:
             wd_loss = regularize_cost(self.weight_decay_pattern,
-                                      tf.contrib.layers.l2_regularizer(self.weight_decay),
+                                      l2_regularizer(self.weight_decay),
                                       name='l2_regularize_loss')
             add_moving_summary(loss, wd_loss)
             total_cost = tf.add_n([loss, wd_loss], name='cost')
...
@@ -98,7 +98,7 @@ class Model(ModelDesc):
         cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')
-        tf.reduce_mean(tf.cast(tf.nn.in_top_k(logits, label, 1)), tf.float32, name='accuracy')
+        tf.reduce_mean(tf.cast(tf.nn.in_top_k(logits, label, 1), tf.float32), name='accuracy')
         wd_cost = tf.multiply(1e-5,
                               regularize_cost('fc.*/W', tf.nn.l2_loss),
...
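Aside: this hunk is a genuine bug fix, not just an API move. In the old line, `tf.float32` was passed as the second argument of `reduce_mean` instead of `cast`. A minimal NumPy sketch of what the corrected line computes (array values here are illustrative):

```python
import numpy as np

logits = np.array([[0.1, 0.9], [0.8, 0.2]])    # illustrative values
label = np.array([1, 0])
in_top1 = logits.argmax(axis=1) == label       # what tf.nn.in_top_k(logits, label, 1) checks
accuracy = in_top1.astype(np.float32).mean()   # tf.reduce_mean(tf.cast(...), name='accuracy')
print(accuracy)                                # 1.0 for these values
```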
@@ -10,6 +10,7 @@ import six
 import tensorflow as tf
 from six.moves import range, zip
+from ..compat import tfv1
 from ..tfutils.common import get_tf_version_tuple
 from ..tfutils.gradproc import ScaleGradient
 from ..tfutils.tower import TrainTowerContext
@@ -101,7 +102,7 @@ class DataParallelBuilder(GraphBuilder):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
             usevs = use_vs[idx] if use_vs is not None else False
             reuse = not usevs and idx > 0
-            with tf.device(device), _maybe_reuse_vs(reuse), TrainTowerContext(
+            with tfv1.device(device), _maybe_reuse_vs(reuse), TrainTowerContext(
                     tower_names[idx],
                     vs_name=tower_names[idx] if usevs else '',
                     index=idx, total=len(towers)):
...
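For readers unfamiliar with the `reuse` logic above: in graph-mode TF, every tower after the first reuses the first tower's variables, unless each tower is given its own variable scope (`usevs`). A stripped-down sketch of that sharing pattern, not tensorpack's exact code, assuming a TF build with the compat.v1 API:

```python
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # graph mode

def tower_fn(x):
    return tf.layers.dense(x, 10, name='fc')  # creates fc/kernel, fc/bias

x = tf.placeholder(tf.float32, [None, 4])
outputs = []
for idx in range(2):
    # tower 0 creates the variables; tower 1 reuses them
    with tf.device('/gpu:{}'.format(idx)), \
            tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0):
        outputs.append(tower_fn(x))
```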
@@ -6,6 +6,7 @@ import operator
 from contextlib import contextmanager
 import tensorflow as tf
+from ..compat import tfv1
 from ..tfutils.common import get_tf_version_tuple
 from ..tfutils.scope_utils import cached_name_scope, under_name_scope
 from ..tfutils.varreplace import custom_getter_scope
@@ -82,7 +83,7 @@ class LeastLoadedDeviceSetter(object):
         # from tensorflow.python.training.device_util import canonicalize
         # from tensorflow.python.distribute.device_util import canonicalize
         def canonicalize(name):  # tensorflow/tensorflow#11484
-            return tf.DeviceSpec.from_string(name).to_string()
+            return tfv1.DeviceSpec.from_string(name).to_string()
         if op.device:
             return op.device
...
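The `canonicalize` helper matters because the same device can be spelled several ways; normalizing first keeps the per-device load counters consistent. A quick sketch (the output shown is my expectation, worth double-checking on your TF version):

```python
import tensorflow.compat.v1 as tf

# Both spellings should normalize to the same canonical string:
print(tf.DeviceSpec.from_string('/gpu:0').to_string())         # expected: '/device:GPU:0'
print(tf.DeviceSpec.from_string('/device:GPU:0').to_string())  # expected: '/device:GPU:0'
```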
@@ -4,6 +4,7 @@
 import tensorflow as tf
+from ..compat import tfv1
 from .batch_norm import BatchNorm
 from .common import VariableHolder, layer_register
@@ -50,8 +51,8 @@ def PReLU(x, init=0.001, name='output'):
     * ``alpha``: learnable slope.
     """
-    init = tf.constant_initializer(init)
-    alpha = tf.get_variable('alpha', [], initializer=init)
+    init = tfv1.constant_initializer(init)
+    alpha = tfv1.get_variable('alpha', [], initializer=init)
     x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
     ret = tf.multiply(x, 0.5, name=name)
...
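The two lines after the variable creation encode PReLU without a branch: ((1 + alpha) * x + (1 - alpha) * |x|) / 2 equals x where x >= 0 and alpha * x where x < 0. A short NumPy check of the identity:

```python
import numpy as np

a = 0.25
x = np.array([-2.0, -0.5, 0.0, 1.5])
fused = ((1 + a) * x + (1 - a) * np.abs(x)) * 0.5   # the branch-free form
direct = np.where(x >= 0, x, a * x)                 # textbook PReLU
assert np.allclose(fused, direct)
```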
@@ -25,7 +25,8 @@ if get_tf_version_tuple() <= (1, 12):
     l2_regularizer = tf.contrib.layers.l2_regularizer
     l1_regularizer = tf.contrib.layers.l1_regularizer
 else:
-    l2_regularizer = tf.keras.regularizers.l2
+    # oh these little dirty details
+    l2_regularizer = lambda x: tf.keras.regularizers.l2(x * 0.5)  # noqa
     l1_regularizer = tf.keras.regularizers.l1
...
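The "dirty detail" being worked around: `tf.contrib.layers.l2_regularizer(scale)` computes `scale * tf.nn.l2_loss(w)`, i.e. `scale * sum(w**2) / 2`, while `tf.keras.regularizers.l2(l)` computes `l * sum(w**2)` with no factor of 1/2. Halving the scale keeps old configs numerically identical. A NumPy check of the equivalence:

```python
import numpy as np

w = np.array([1.0, 2.0, 3.0])
scale = 1e-4
contrib_style = scale * (w ** 2).sum() / 2     # tf.contrib semantics (l2_loss halves)
keras_style = (scale * 0.5) * (w ** 2).sum()   # keras l2 with the halved scale
assert np.isclose(contrib_style, keras_style)
```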
@@ -8,6 +8,7 @@ import six
 import tensorflow as tf
 from six.moves import queue, range
+from ..compat import tfv1
 from ..tfutils.model_utils import describe_trainable_vars
 from ..utils import logger
 from ..utils.concurrency import DIE, ShareSessionThread, StoppableThread
@@ -162,7 +163,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
     def start(self):
         if self._need_default_sess:
-            assert tf.get_default_session() is not None, \
+            assert tfv1.get_default_session() is not None, \
                 "Not session is bind to predictors, " \
                 "MultiThreadAsyncPredictor.start() has to be called under a default session!"
         for t in self.threads:
...
@@ -3,7 +3,7 @@
 import six
-import tensorflow as tf
+from ..compat import tfv1 as tf
 from ..graph_builder import ModelDescBase
 from ..tfutils import get_default_sess_config
...
@@ -6,7 +6,9 @@ from collections import defaultdict
 from contextlib import contextmanager
 from functools import wraps
 from inspect import getmembers, isfunction
+import tensorflow as tf
+from ..compat import is_tfv2
 from ..utils import logger
 from .tower import get_current_tower_context
@@ -138,6 +140,8 @@ def enable_argscope_for_module(module, log_shape=True):
     Args:
         log_shape (bool): print input/output shapes of each function.
     """
+    if is_tfv2() and module == tf.layers:
+        module = tf.compat.v1.layers
     for name, obj in getmembers(module):
         if isfunction(obj):
             setattr(module, name, enable_argscope_for_function(obj,
...
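`enable_argscope_for_module` works by rebinding every plain function in the module object, which is why `tf.layers` gets redirected to its compat.v1 twin first under TF2. A stripped-down sketch of that patching loop (generic, not tensorpack's exact code):

```python
from inspect import getmembers, isfunction

def patch_module_functions(module, wrapper):
    """Rebind every plain function in `module` to wrapper(function)."""
    for name, obj in getmembers(module):
        if isfunction(obj):
            setattr(module, name, wrapper(obj))

def log_calls(fn):
    def wrapped(*args, **kwargs):
        print('calling', fn.__name__)
        return fn(*args, **kwargs)
    return wrapped

# usage sketch: patch_module_functions(some_module, log_calls)
```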
@@ -12,6 +12,7 @@ from tensorflow.python.framework import graph_util
 from tensorflow.python.platform import gfile
 from tensorflow.python.tools import optimize_for_inference_lib
+from ..compat import is_tfv2, tfv1
 from ..input_source import PlaceholderInput
 from ..tfutils.common import get_tensors_by_names, get_tf_version_tuple
 from ..tfutils.tower import PredictTowerContext
@@ -60,7 +61,7 @@ class ModelExporter(object):
             self.config.session_init._setup_graph()
             # we cannot use "self.config.session_creator.create_session()" here since it finalizes the graph
-            sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+            sess = tfv1.Session(config=tfv1.ConfigProto(allow_soft_placement=True))
             self.config.session_init._run_init(sess)
             dtypes = [n.dtype for n in input_tensors]
@@ -88,7 +89,7 @@ class ModelExporter(object):
         logger.info("Output graph written to {}.".format(filename))
     def export_serving(self, filename,
-                       tags=[tf.saved_model.tag_constants.SERVING],
+                       tags=[tf.saved_model.SERVING if is_tfv2() else tf.saved_model.tag_constants.SERVING],
                        signature_name='prediction_pipeline'):
         """
         Converts a checkpoint and graph to a servable for TensorFlow Serving.
@@ -121,21 +122,22 @@ class ModelExporter(object):
             self.config.tower_func(*input.get_input_tensors())
             input_tensors = get_tensors_by_names(self.config.input_names)
-            inputs_signatures = {t.name: tf.saved_model.utils.build_tensor_info(t) for t in input_tensors}
+            saved_model = tfv1.saved_model.utils
+            inputs_signatures = {t.name: saved_model.build_tensor_info(t) for t in input_tensors}
             output_tensors = get_tensors_by_names(self.config.output_names)
-            outputs_signatures = {t.name: tf.saved_model.utils.build_tensor_info(t) for t in output_tensors}
+            outputs_signatures = {t.name: saved_model.build_tensor_info(t) for t in output_tensors}
             self.config.session_init._setup_graph()
             # we cannot use "self.config.session_creator.create_session()" here since it finalizes the graph
-            sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+            sess = tfv1.Session(config=tfv1.ConfigProto(allow_soft_placement=True))
             self.config.session_init._run_init(sess)
-            builder = tf.saved_model.builder.SavedModelBuilder(filename)
-            prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
+            builder = tfv1.saved_model.builder.SavedModelBuilder(filename)
+            prediction_signature = tfv1.saved_model.signature_def_utils.build_signature_def(
                 inputs=inputs_signatures,
                 outputs=outputs_signatures,
-                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
+                method_name=tfv1.saved_model.signature_constants.PREDICT_METHOD_NAME)
             builder.add_meta_graph_and_variables(
                 sess, tags,
...
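For completeness, a hedged sketch of consuming such an export with the TF1 loader API (the export directory path is illustrative):

```python
import tensorflow.compat.v1 as tf

# Load a SavedModel exported as above; '/tmp/exported' is a placeholder path.
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], '/tmp/exported')
    # tensors can then be fetched by the names recorded in the signature
```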
@@ -8,6 +8,7 @@ from abc import ABCMeta, abstractmethod
 import six
 import tensorflow as tf
+from ..compat import tfv1
 from ..utils import logger
 from .summary import add_moving_summary
 from .symbolic_functions import print_stat, rms
@@ -40,11 +41,11 @@ class GradientProcessor(object):
         # reuse the old name_scope, if process() is called multiple times
         if self._name_scope is None:
-            with tf.name_scope(type(self).__name__) as scope:
+            with tfv1.name_scope(type(self).__name__) as scope:
                 self._name_scope = scope
                 return self._process(grads)
         else:
-            with tf.name_scope(self._name_scope):
+            with tfv1.name_scope(self._name_scope):
                 return self._process(grads)
     @abstractmethod
@@ -175,7 +176,7 @@ class SummaryGradient(MapGradient):
             return grad
         if name not in SummaryGradient._summaried_gradient:
             SummaryGradient._summaried_gradient.add(name)
-            tf.summary.histogram(name + '-grad', grad, collections=self._coll)
+            tfv1.summary.histogram(name + '-grad', grad, collections=self._coll)
             add_moving_summary(rms(grad, name=name + '/rms'))
         return grad
...
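Why the scope string is cached at all: in graph-mode TF, opening `tf.name_scope('X')` twice yields 'X' then 'X_1', but re-opening with the captured scope string (which ends in '/') re-enters the original scope. A small sketch of that behavior (assumes graph mode):

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()  # graph mode, so op names are meaningful

with tf.name_scope('Proc') as scope:
    a = tf.constant(1, name='c')   # -> 'Proc/c:0'
    print(scope)                   # 'Proc/' (note the trailing slash)
with tf.name_scope(scope):         # re-enters 'Proc/' instead of creating 'Proc_1'
    b = tf.constant(1, name='c')   # -> 'Proc/c_1:0'
print(a.name, b.name)
```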
@@ -20,7 +20,7 @@ class ProxyOptimizer(tfv1.train.Optimizer):
     A transparent proxy which delegates all methods of :class:`tf.train.Optimizer`
     """
     def __init__(self, opt, name='ProxyOptimizer'):
-        assert isinstance(opt, tf.train.Optimizer), opt
+        assert isinstance(opt, tfv1.train.Optimizer), opt
         super(ProxyOptimizer, self).__init__(False, name)
         self._opt = opt
...
@@ -4,8 +4,8 @@
 import os
 import numpy as np
 import six
-import tensorflow as tf
+from ..compat import tfv1 as tf
 from ..utils import logger
 from .common import get_op_tensor_name
 from .varmanip import SessionUpdate, get_checkpoint_path, get_savename_from_varname, is_training_name
...
@@ -7,6 +7,7 @@ import pprint
 import six
 import tensorflow as tf
+from ..compat import tfv1
 from ..utils import logger
 from .common import get_op_tensor_name
@@ -84,7 +85,7 @@ class SessionUpdate(object):
             return None
         if hasattr(value, 'dtype'):
-            vartype = var.value().dtype
+            vartype = var.dtype
             if vartype != value.dtype:
                 msg = "Variable {} has dtype {} but was given a value of dtype {}.".format(name, vartype, value.dtype)
                 newtype = upcast(var.dtype.base_dtype, value.dtype)
@@ -172,7 +173,7 @@ def get_checkpoint_path(model_path):
     if os.path.basename(model_path) == model_path:
         model_path = os.path.join('.', model_path)  # avoid #4921 and #6142
     if os.path.basename(model_path) == 'checkpoint':
-        assert tf.gfile.Exists(model_path), model_path
+        assert tfv1.gfile.Exists(model_path), model_path
         model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
         # to be consistent with either v1 or v2
@@ -186,7 +187,7 @@ def get_checkpoint_path(model_path):
         logger.info(
             "Checkpoint path {} is auto-corrected to {}.".format(model_path, new_path))
         model_path = new_path
-    assert tf.gfile.Exists(model_path) or tf.gfile.Exists(model_path + '.index'), model_path
+    assert tfv1.gfile.Exists(model_path) or tfv1.gfile.Exists(model_path + '.index'), model_path
     return model_path
@@ -200,7 +201,7 @@ def load_chkpt_vars(model_path):
         dict: a name:value dict
     """
     model_path = get_checkpoint_path(model_path)
-    reader = tf.train.NewCheckpointReader(model_path)
+    reader = tfv1.train.NewCheckpointReader(model_path)
     var_names = reader.get_variable_to_shape_map().keys()
     result = {}
     for n in var_names:
...
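For context on the last hunk, `NewCheckpointReader` is the low-level way to read raw variables out of a checkpoint without building a graph. A usage sketch (the checkpoint prefix is an illustrative path):

```python
import tensorflow.compat.v1 as tf

# Read variables straight from a checkpoint, as load_chkpt_vars does above.
reader = tf.train.NewCheckpointReader('/path/to/model-10000')  # illustrative prefix
for name, shape in reader.get_variable_to_shape_map().items():
    value = reader.get_tensor(name)
    print(name, shape, value.dtype)
```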
@@ -10,7 +10,7 @@ exclude = .git,
           examples,
           docs/conf.py
           snippet,
-          examples-old,
+          examples_v2,
           _test.py,
 [isort]
...