Commit b673b24c authored by Yuxin Wu's avatar Yuxin Wu

get_tf_version_number -> get_tf_version_tuple

parent 7eb08df1
...@@ -4,13 +4,11 @@ ...@@ -4,13 +4,11 @@
import abc import abc
import tensorflow as tf import tensorflow as tf
import tensorpack
from tensorpack import ModelDesc from tensorpack import ModelDesc
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.tfutils import ( from tensorpack.tfutils import (
varreplace, summary, get_current_tower_context, optimizer, gradproc) varreplace, summary, get_current_tower_context, optimizer, gradproc)
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
assert tensorpack.tfutils.common.get_tf_version_number() >= 1.2
class Model(ModelDesc): class Model(ModelDesc):
......
...@@ -22,7 +22,7 @@ assert six.PY3, "FasterRCNN requires Python 3!" ...@@ -22,7 +22,7 @@ assert six.PY3, "FasterRCNN requires Python 3!"
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils import optimizer from tensorpack.tfutils import optimizer
from tensorpack.tfutils.common import get_tf_version_number from tensorpack.tfutils.common import get_tf_version_tuple
import tensorpack.utils.viz as tpviz import tensorpack.utils.viz as tpviz
from coco import COCODetection from coco import COCODetection
...@@ -514,7 +514,7 @@ if __name__ == '__main__': ...@@ -514,7 +514,7 @@ if __name__ == '__main__':
parser.add_argument('--config', help="A list of KEY=VALUE to overwrite those defined in config.py", parser.add_argument('--config', help="A list of KEY=VALUE to overwrite those defined in config.py",
nargs='+') nargs='+')
if get_tf_version_number() < 1.6: if get_tf_version_tuple() < (1, 6):
# https://github.com/tensorflow/tensorflow/issues/14657 # https://github.com/tensorflow/tensorflow/issues/14657
logger.warn("TF<1.6 has a bug which may lead to crash in FasterRCNN training if you're unlucky.") logger.warn("TF<1.6 has a bug which may lead to crash in FasterRCNN training if you're unlucky.")
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# Author: Yuxin Wu # Author: Yuxin Wu
from tensorpack import * from tensorpack import *
from tensorpack.tfutils import get_tf_version_number from tensorpack.tfutils import get_tf_version_tuple
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf import tensorflow as tf
...@@ -83,7 +83,7 @@ class Model(DCGAN.Model): ...@@ -83,7 +83,7 @@ class Model(DCGAN.Model):
if __name__ == '__main__': if __name__ == '__main__':
assert get_tf_version_number() >= 1.4 assert get_tf_version_tuple() >= (1, 4)
args = DCGAN.get_args(default_batch=64, default_z_dim=128) args = DCGAN.get_args(default_batch=64, default_z_dim=128)
M = Model(shape=args.final_size, batch=args.batch, z_dim=args.z_dim) M = Model(shape=args.final_size, batch=args.batch, z_dim=args.z_dim)
if args.sample: if args.sample:
......
...@@ -8,7 +8,6 @@ import os ...@@ -8,7 +8,6 @@ import os
from .base import Callback from .base import Callback
from ..utils import logger from ..utils import logger
from ..tfutils.common import get_tf_version_number
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver'] __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
...@@ -51,13 +50,6 @@ class ModelSaver(Callback): ...@@ -51,13 +50,6 @@ class ModelSaver(Callback):
vars.extend(tf.get_collection(key)) vars.extend(tf.get_collection(key))
vars = list(set(vars)) vars = list(set(vars))
self.path = os.path.join(self.checkpoint_dir, 'model') self.path = os.path.join(self.checkpoint_dir, 'model')
if get_tf_version_number() <= 1.1:
self.saver = tf.train.Saver(
var_list=vars,
max_to_keep=self._max_to_keep,
keep_checkpoint_every_n_hours=self._keep_every_n_hours,
write_version=tf.train.SaverDef.V2)
else:
self.saver = tf.train.Saver( self.saver = tf.train.Saver(
var_list=vars, var_list=vars,
max_to_keep=self._max_to_keep, max_to_keep=self._max_to_keep,
......
...@@ -8,7 +8,7 @@ import tensorflow as tf ...@@ -8,7 +8,7 @@ import tensorflow as tf
from ..tfutils.varreplace import custom_getter_scope from ..tfutils.varreplace import custom_getter_scope
from ..tfutils.scope_utils import under_name_scope, cached_name_scope from ..tfutils.scope_utils import under_name_scope, cached_name_scope
from ..tfutils.common import get_tf_version_number from ..tfutils.common import get_tf_version_tuple
from ..utils.argtools import call_only_once from ..utils.argtools import call_only_once
from ..utils import logger from ..utils import logger
...@@ -67,7 +67,7 @@ class LeastLoadedDeviceSetter(object): ...@@ -67,7 +67,7 @@ class LeastLoadedDeviceSetter(object):
self.ps_sizes = [0] * len(self.ps_devices) self.ps_sizes = [0] * len(self.ps_devices)
def __call__(self, op): def __call__(self, op):
if get_tf_version_number() >= 1.8: if get_tf_version_tuple() >= (1, 8):
from tensorflow.python.training.device_util import canonicalize from tensorflow.python.training.device_util import canonicalize
else: else:
def canonicalize(name): # tensorflow/tensorflow#11484 def canonicalize(name): # tensorflow/tensorflow#11484
......
...@@ -7,7 +7,7 @@ from tensorflow.python.training import moving_averages ...@@ -7,7 +7,7 @@ from tensorflow.python.training import moving_averages
from ..utils import logger from ..utils import logger
from ..utils.argtools import get_data_format from ..utils.argtools import get_data_format
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_tf_version_number from ..tfutils.common import get_tf_version_tuple
from .common import layer_register, VariableHolder from .common import layer_register, VariableHolder
from .tflayer import convert_to_tflayer_args from .tflayer import convert_to_tflayer_args
...@@ -128,7 +128,7 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5, ...@@ -128,7 +128,7 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
xn = tf.squeeze(xn, [1, 2]) xn = tf.squeeze(xn, [1, 2])
else: else:
if ctx.is_training: if ctx.is_training:
assert get_tf_version_number() >= 1.4, \ assert get_tf_version_tuple() >= (1, 4), \
"Fine tuning a BatchNorm model with fixed statistics is only " \ "Fine tuning a BatchNorm model with fixed statistics is only " \
"supported after https://github.com/tensorflow/tensorflow/pull/12580 " "supported after https://github.com/tensorflow/tensorflow/pull/12580 "
if ctx.is_main_training_tower: # only warn in first tower if ctx.is_main_training_tower: # only warn in first tower
......
...@@ -11,7 +11,7 @@ import six ...@@ -11,7 +11,7 @@ import six
from ..utils import logger from ..utils import logger
from ..utils.argtools import get_data_format from ..utils.argtools import get_data_format
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_tf_version_number from ..tfutils.common import get_tf_version_tuple
from ..tfutils.collection import backup_collection, restore_collection from ..tfutils.collection import backup_collection, restore_collection
from .common import layer_register, VariableHolder from .common import layer_register, VariableHolder
from .tflayer import convert_to_tflayer_args, rename_get_variable from .tflayer import convert_to_tflayer_args, rename_get_variable
...@@ -155,9 +155,9 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5, ...@@ -155,9 +155,9 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
if training is None: if training is None:
training = ctx.is_training training = ctx.is_training
training = bool(training) training = bool(training)
TF_version = get_tf_version_number() TF_version = get_tf_version_tuple()
if not training and ctx.is_training: if not training and ctx.is_training:
assert TF_version >= 1.4, \ assert TF_version >= (1, 4), \
"Fine tuning a BatchNorm model with fixed statistics is only " \ "Fine tuning a BatchNorm model with fixed statistics is only " \
"supported after https://github.com/tensorflow/tensorflow/pull/12580 " "supported after https://github.com/tensorflow/tensorflow/pull/12580 "
if ctx.is_main_training_tower: # only warn in first tower if ctx.is_main_training_tower: # only warn in first tower
...@@ -178,7 +178,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5, ...@@ -178,7 +178,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
gamma_initializer=gamma_initializer, gamma_initializer=gamma_initializer,
fused=(ndims == 4 and axis in [1, 3]), fused=(ndims == 4 and axis in [1, 3]),
_reuse=tf.get_variable_scope().reuse) _reuse=tf.get_variable_scope().reuse)
if TF_version >= 1.5: if TF_version >= (1, 5):
tf_args['virtual_batch_size'] = virtual_batch_size tf_args['virtual_batch_size'] = virtual_batch_size
else: else:
assert virtual_batch_size is None, "Feature not supported in this version of TF!" assert virtual_batch_size is None, "Feature not supported in this version of TF!"
...@@ -220,7 +220,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5, ...@@ -220,7 +220,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis) batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)
if sync_statistics == 'nccl': if sync_statistics == 'nccl':
if six.PY3 and TF_version <= 1.9 and ctx.is_main_training_tower: if six.PY3 and TF_version <= (1, 9) and ctx.is_main_training_tower:
logger.warn("A TensorFlow bug will cause cross-GPU BatchNorm to fail. " logger.warn("A TensorFlow bug will cause cross-GPU BatchNorm to fail. "
"Apply this patch: https://github.com/tensorflow/tensorflow/pull/20360") "Apply this patch: https://github.com/tensorflow/tensorflow/pull/20360")
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import tensorflow as tf import tensorflow as tf
from .common import layer_register, VariableHolder from .common import layer_register, VariableHolder
from ..tfutils.common import get_tf_version_number from ..tfutils.common import get_tf_version_tuple
from ..utils.argtools import shape2d, shape4d, get_data_format from ..utils.argtools import shape2d, shape4d, get_data_format
from .tflayer import rename_get_variable, convert_to_tflayer_args from .tflayer import rename_get_variable, convert_to_tflayer_args
...@@ -86,14 +86,14 @@ def Conv2D( ...@@ -86,14 +86,14 @@ def Conv2D(
out_channel = filters out_channel = filters
assert out_channel % split == 0 assert out_channel % split == 0
assert dilation_rate == (1, 1) or get_tf_version_number() >= 1.5, 'TF>=1.5 required for group dilated conv' assert dilation_rate == (1, 1) or get_tf_version_tuple() >= (1, 5), 'TF>=1.5 required for group dilated conv'
kernel_shape = shape2d(kernel_size) kernel_shape = shape2d(kernel_size)
filter_shape = kernel_shape + [in_channel / split, out_channel] filter_shape = kernel_shape + [in_channel / split, out_channel]
stride = shape4d(strides, data_format=data_format) stride = shape4d(strides, data_format=data_format)
kwargs = dict(data_format=data_format) kwargs = dict(data_format=data_format)
if get_tf_version_number() >= 1.5: if get_tf_version_tuple() >= (1, 5):
kwargs['dilations'] = shape4d(dilation_rate, data_format=data_format) kwargs['dilations'] = shape4d(dilation_rate, data_format=data_format)
W = tf.get_variable( W = tf.get_variable(
......
...@@ -6,7 +6,7 @@ import six ...@@ -6,7 +6,7 @@ import six
import functools import functools
from ..utils.argtools import get_data_format from ..utils.argtools import get_data_format
from ..tfutils.common import get_tf_version_number from ..tfutils.common import get_tf_version_tuple
from ..tfutils.varreplace import custom_getter_scope from ..tfutils.varreplace import custom_getter_scope
...@@ -112,7 +112,7 @@ def rename_tflayer_get_variable(): ...@@ -112,7 +112,7 @@ def rename_tflayer_get_variable():
def monkeypatch_tf_layers(): def monkeypatch_tf_layers():
if get_tf_version_number() < 1.4: if get_tf_version_tuple() < (1, 4):
if not hasattr(tf.layers, 'Dense'): if not hasattr(tf.layers, 'Dense'):
from tensorflow.python.layers.core import Dense from tensorflow.python.layers.core import Dense
tf.layers.Dense = Dense tf.layers.Dense = Dense
......
...@@ -6,11 +6,10 @@ from abc import abstractmethod, ABCMeta ...@@ -6,11 +6,10 @@ from abc import abstractmethod, ABCMeta
import tensorflow as tf import tensorflow as tf
import six import six
from ..tfutils.common import get_tensors_by_names, get_tf_version_number from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import PredictTowerContext from ..tfutils.tower import PredictTowerContext
from ..input_source import PlaceholderInput from ..input_source import PlaceholderInput
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
from ..utils.argtools import log_once
from ..utils.utils import execute_only_once from ..utils.utils import execute_only_once
__all__ = ['PredictorBase', 'AsyncPredictorBase', __all__ = ['PredictorBase', 'AsyncPredictorBase',
...@@ -110,9 +109,7 @@ class OnlinePredictor(PredictorBase): ...@@ -110,9 +109,7 @@ class OnlinePredictor(PredictorBase):
self.input_tensors = input_tensors self.input_tensors = input_tensors
self.output_tensors = output_tensors self.output_tensors = output_tensors
self.sess = sess self.sess = sess
self._use_callable = get_tf_version_number() >= 1.2
if self._use_callable:
if sess is not None: if sess is not None:
self._callable = sess.make_callable( self._callable = sess.make_callable(
fetches=output_tensors, fetches=output_tensors,
...@@ -120,9 +117,6 @@ class OnlinePredictor(PredictorBase): ...@@ -120,9 +117,6 @@ class OnlinePredictor(PredictorBase):
accept_options=self.ACCEPT_OPTIONS) accept_options=self.ACCEPT_OPTIONS)
else: else:
self._callable = None self._callable = None
else:
log_once(
"TF>=1.2 is recommended for better performance of predictor!", 'warn')
def _do_call_old(self, dp): def _do_call_old(self, dp):
feed = dict(zip(self.input_tensors, dp)) feed = dict(zip(self.input_tensors, dp))
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import tensorflow as tf import tensorflow as tf
from six.moves import map from six.moves import map
from ..utils.argtools import graph_memoized from ..utils.argtools import graph_memoized
from ..utils.develop import deprecated
__all__ = ['get_default_sess_config', __all__ = ['get_default_sess_config',
'get_global_step_value', 'get_global_step_value',
...@@ -12,7 +13,6 @@ __all__ = ['get_default_sess_config', ...@@ -12,7 +13,6 @@ __all__ = ['get_default_sess_config',
# 'get_op_tensor_name', # 'get_op_tensor_name',
# 'get_tensors_by_names', # 'get_tensors_by_names',
# 'get_op_or_tensor_by_name', # 'get_op_or_tensor_by_name',
# 'get_tf_version_number',
] ]
...@@ -132,8 +132,13 @@ def get_op_or_tensor_by_name(name): ...@@ -132,8 +132,13 @@ def get_op_or_tensor_by_name(name):
return list(map(f, name)) return list(map(f, name))
@deprecated("You should use get_tf_version_tuple instead due to the existence of TF 1.10")
def get_tf_version_number(): def get_tf_version_number():
return float('.'.join(tf.VERSION.split('.')[:2]))
def get_tf_version_tuple():
""" """
Return a float (for comparison), indicating tensorflow version. Return TensorFlow version as a 2-element tuple (for comparison).
""" """
return float('.'.join(tf.VERSION.split('.')[:2])) return tuple(map(int, tf.VERSION.split('.')[:2]))
...@@ -7,7 +7,7 @@ import functools ...@@ -7,7 +7,7 @@ import functools
from contextlib import contextmanager from contextlib import contextmanager
from ..utils.argtools import graph_memoized from ..utils.argtools import graph_memoized
from .common import get_tf_version_number from .common import get_tf_version_tuple
__all__ = ['auto_reuse_variable_scope', 'cached_name_scope', 'under_name_scope'] __all__ = ['auto_reuse_variable_scope', 'cached_name_scope', 'under_name_scope']
...@@ -39,7 +39,7 @@ def auto_reuse_variable_scope(func): ...@@ -39,7 +39,7 @@ def auto_reuse_variable_scope(func):
h = hash((tf.get_default_graph(), scope.name)) h = hash((tf.get_default_graph(), scope.name))
# print("Entering " + scope.name + " reuse: " + str(h in used_scope)) # print("Entering " + scope.name + " reuse: " + str(h in used_scope))
if h in used_scope: if h in used_scope:
if get_tf_version_number() >= 1.5: if get_tf_version_tuple() >= (1, 5):
with tf.variable_scope(scope, reuse=True, auxiliary_name_scope=False): with tf.variable_scope(scope, reuse=True, auxiliary_name_scope=False):
return func(*args, **kwargs) return func(*args, **kwargs)
else: else:
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import tensorflow as tf import tensorflow as tf
from contextlib import contextmanager from contextlib import contextmanager
from .common import get_tf_version_number from .common import get_tf_version_tuple
__all__ = ['freeze_variables', 'remap_variables'] __all__ = ['freeze_variables', 'remap_variables']
...@@ -13,7 +13,7 @@ __all__ = ['freeze_variables', 'remap_variables'] ...@@ -13,7 +13,7 @@ __all__ = ['freeze_variables', 'remap_variables']
@contextmanager @contextmanager
def custom_getter_scope(custom_getter): def custom_getter_scope(custom_getter):
scope = tf.get_variable_scope() scope = tf.get_variable_scope()
if get_tf_version_number() >= 1.5: if get_tf_version_tuple() >= (1, 5):
with tf.variable_scope( with tf.variable_scope(
scope, custom_getter=custom_getter, scope, custom_getter=custom_getter,
auxiliary_name_scope=False): auxiliary_name_scope=False):
......
from case_script import TestPythonScript from case_script import TestPythonScript
from tensorpack.tfutils.common import get_tf_version_number from tensorpack.tfutils.common import get_tf_version_tuple
class InfoGANTest(TestPythonScript): class InfoGANTest(TestPythonScript):
...@@ -10,6 +10,6 @@ class InfoGANTest(TestPythonScript): ...@@ -10,6 +10,6 @@ class InfoGANTest(TestPythonScript):
return '../examples/GAN/InfoGAN-mnist.py' return '../examples/GAN/InfoGAN-mnist.py'
def test(self): def test(self):
if get_tf_version_number() < 1.4: if get_tf_version_tuple() < (1, 4):
return True # requires leaky_relu return True # requires leaky_relu
self.assertSurvive(self.script, args=None) self.assertSurvive(self.script, args=None)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment