Commit b673b24c authored by Yuxin Wu's avatar Yuxin Wu

get_tf_version_number -> get_tf_version_tuple

parent 7eb08df1
......@@ -4,13 +4,11 @@
import abc
import tensorflow as tf
import tensorpack
from tensorpack import ModelDesc
from tensorpack.utils import logger
from tensorpack.tfutils import (
varreplace, summary, get_current_tower_context, optimizer, gradproc)
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
assert tensorpack.tfutils.common.get_tf_version_number() >= 1.2
class Model(ModelDesc):
......
......@@ -22,7 +22,7 @@ assert six.PY3, "FasterRCNN requires Python 3!"
from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils import optimizer
from tensorpack.tfutils.common import get_tf_version_number
from tensorpack.tfutils.common import get_tf_version_tuple
import tensorpack.utils.viz as tpviz
from coco import COCODetection
......@@ -514,7 +514,7 @@ if __name__ == '__main__':
parser.add_argument('--config', help="A list of KEY=VALUE to overwrite those defined in config.py",
nargs='+')
if get_tf_version_number() < 1.6:
if get_tf_version_tuple() < (1, 6):
# https://github.com/tensorflow/tensorflow/issues/14657
logger.warn("TF<1.6 has a bug which may lead to crash in FasterRCNN training if you're unlucky.")
......
......@@ -4,7 +4,7 @@
# Author: Yuxin Wu
from tensorpack import *
from tensorpack.tfutils import get_tf_version_number
from tensorpack.tfutils import get_tf_version_tuple
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf
......@@ -83,7 +83,7 @@ class Model(DCGAN.Model):
if __name__ == '__main__':
assert get_tf_version_number() >= 1.4
assert get_tf_version_tuple() >= (1, 4)
args = DCGAN.get_args(default_batch=64, default_z_dim=128)
M = Model(shape=args.final_size, batch=args.batch, z_dim=args.z_dim)
if args.sample:
......
......@@ -8,7 +8,6 @@ import os
from .base import Callback
from ..utils import logger
from ..tfutils.common import get_tf_version_number
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
......@@ -51,13 +50,6 @@ class ModelSaver(Callback):
vars.extend(tf.get_collection(key))
vars = list(set(vars))
self.path = os.path.join(self.checkpoint_dir, 'model')
if get_tf_version_number() <= 1.1:
self.saver = tf.train.Saver(
var_list=vars,
max_to_keep=self._max_to_keep,
keep_checkpoint_every_n_hours=self._keep_every_n_hours,
write_version=tf.train.SaverDef.V2)
else:
self.saver = tf.train.Saver(
var_list=vars,
max_to_keep=self._max_to_keep,
......
......@@ -8,7 +8,7 @@ import tensorflow as tf
from ..tfutils.varreplace import custom_getter_scope
from ..tfutils.scope_utils import under_name_scope, cached_name_scope
from ..tfutils.common import get_tf_version_number
from ..tfutils.common import get_tf_version_tuple
from ..utils.argtools import call_only_once
from ..utils import logger
......@@ -67,7 +67,7 @@ class LeastLoadedDeviceSetter(object):
self.ps_sizes = [0] * len(self.ps_devices)
def __call__(self, op):
if get_tf_version_number() >= 1.8:
if get_tf_version_tuple() >= (1, 8):
from tensorflow.python.training.device_util import canonicalize
else:
def canonicalize(name): # tensorflow/tensorflow#11484
......
......@@ -7,7 +7,7 @@ from tensorflow.python.training import moving_averages
from ..utils import logger
from ..utils.argtools import get_data_format
from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_tf_version_number
from ..tfutils.common import get_tf_version_tuple
from .common import layer_register, VariableHolder
from .tflayer import convert_to_tflayer_args
......@@ -128,7 +128,7 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
xn = tf.squeeze(xn, [1, 2])
else:
if ctx.is_training:
assert get_tf_version_number() >= 1.4, \
assert get_tf_version_tuple() >= (1, 4), \
"Fine tuning a BatchNorm model with fixed statistics is only " \
"supported after https://github.com/tensorflow/tensorflow/pull/12580 "
if ctx.is_main_training_tower: # only warn in first tower
......
......@@ -11,7 +11,7 @@ import six
from ..utils import logger
from ..utils.argtools import get_data_format
from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_tf_version_number
from ..tfutils.common import get_tf_version_tuple
from ..tfutils.collection import backup_collection, restore_collection
from .common import layer_register, VariableHolder
from .tflayer import convert_to_tflayer_args, rename_get_variable
......@@ -155,9 +155,9 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
if training is None:
training = ctx.is_training
training = bool(training)
TF_version = get_tf_version_number()
TF_version = get_tf_version_tuple()
if not training and ctx.is_training:
assert TF_version >= 1.4, \
assert TF_version >= (1, 4), \
"Fine tuning a BatchNorm model with fixed statistics is only " \
"supported after https://github.com/tensorflow/tensorflow/pull/12580 "
if ctx.is_main_training_tower: # only warn in first tower
......@@ -178,7 +178,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
gamma_initializer=gamma_initializer,
fused=(ndims == 4 and axis in [1, 3]),
_reuse=tf.get_variable_scope().reuse)
if TF_version >= 1.5:
if TF_version >= (1, 5):
tf_args['virtual_batch_size'] = virtual_batch_size
else:
assert virtual_batch_size is None, "Feature not supported in this version of TF!"
......@@ -220,7 +220,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)
if sync_statistics == 'nccl':
if six.PY3 and TF_version <= 1.9 and ctx.is_main_training_tower:
if six.PY3 and TF_version <= (1, 9) and ctx.is_main_training_tower:
logger.warn("A TensorFlow bug will cause cross-GPU BatchNorm to fail. "
"Apply this patch: https://github.com/tensorflow/tensorflow/pull/20360")
......
......@@ -4,7 +4,7 @@
import tensorflow as tf
from .common import layer_register, VariableHolder
from ..tfutils.common import get_tf_version_number
from ..tfutils.common import get_tf_version_tuple
from ..utils.argtools import shape2d, shape4d, get_data_format
from .tflayer import rename_get_variable, convert_to_tflayer_args
......@@ -86,14 +86,14 @@ def Conv2D(
out_channel = filters
assert out_channel % split == 0
assert dilation_rate == (1, 1) or get_tf_version_number() >= 1.5, 'TF>=1.5 required for group dilated conv'
assert dilation_rate == (1, 1) or get_tf_version_tuple() >= (1, 5), 'TF>=1.5 required for group dilated conv'
kernel_shape = shape2d(kernel_size)
filter_shape = kernel_shape + [in_channel / split, out_channel]
stride = shape4d(strides, data_format=data_format)
kwargs = dict(data_format=data_format)
if get_tf_version_number() >= 1.5:
if get_tf_version_tuple() >= (1, 5):
kwargs['dilations'] = shape4d(dilation_rate, data_format=data_format)
W = tf.get_variable(
......
......@@ -6,7 +6,7 @@ import six
import functools
from ..utils.argtools import get_data_format
from ..tfutils.common import get_tf_version_number
from ..tfutils.common import get_tf_version_tuple
from ..tfutils.varreplace import custom_getter_scope
......@@ -112,7 +112,7 @@ def rename_tflayer_get_variable():
def monkeypatch_tf_layers():
if get_tf_version_number() < 1.4:
if get_tf_version_tuple() < (1, 4):
if not hasattr(tf.layers, 'Dense'):
from tensorflow.python.layers.core import Dense
tf.layers.Dense = Dense
......
......@@ -6,11 +6,10 @@ from abc import abstractmethod, ABCMeta
import tensorflow as tf
import six
from ..tfutils.common import get_tensors_by_names, get_tf_version_number
from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import PredictTowerContext
from ..input_source import PlaceholderInput
from ..utils.develop import log_deprecated
from ..utils.argtools import log_once
from ..utils.utils import execute_only_once
__all__ = ['PredictorBase', 'AsyncPredictorBase',
......@@ -110,9 +109,7 @@ class OnlinePredictor(PredictorBase):
self.input_tensors = input_tensors
self.output_tensors = output_tensors
self.sess = sess
self._use_callable = get_tf_version_number() >= 1.2
if self._use_callable:
if sess is not None:
self._callable = sess.make_callable(
fetches=output_tensors,
......@@ -120,9 +117,6 @@ class OnlinePredictor(PredictorBase):
accept_options=self.ACCEPT_OPTIONS)
else:
self._callable = None
else:
log_once(
"TF>=1.2 is recommended for better performance of predictor!", 'warn')
def _do_call_old(self, dp):
feed = dict(zip(self.input_tensors, dp))
......
......@@ -5,6 +5,7 @@
import tensorflow as tf
from six.moves import map
from ..utils.argtools import graph_memoized
from ..utils.develop import deprecated
__all__ = ['get_default_sess_config',
'get_global_step_value',
......@@ -12,7 +13,6 @@ __all__ = ['get_default_sess_config',
# 'get_op_tensor_name',
# 'get_tensors_by_names',
# 'get_op_or_tensor_by_name',
# 'get_tf_version_number',
]
......@@ -132,8 +132,13 @@ def get_op_or_tensor_by_name(name):
return list(map(f, name))
@deprecated("You should use get_tf_version_tuple instead due to the existence of TF 1.10")
def get_tf_version_number():
return float('.'.join(tf.VERSION.split('.')[:2]))
def get_tf_version_tuple():
"""
Return a float (for comparison), indicating tensorflow version.
Return TensorFlow version as a 2-element tuple (for comparison).
"""
return float('.'.join(tf.VERSION.split('.')[:2]))
return tuple(map(int, tf.VERSION.split('.')[:2]))
......@@ -7,7 +7,7 @@ import functools
from contextlib import contextmanager
from ..utils.argtools import graph_memoized
from .common import get_tf_version_number
from .common import get_tf_version_tuple
__all__ = ['auto_reuse_variable_scope', 'cached_name_scope', 'under_name_scope']
......@@ -39,7 +39,7 @@ def auto_reuse_variable_scope(func):
h = hash((tf.get_default_graph(), scope.name))
# print("Entering " + scope.name + " reuse: " + str(h in used_scope))
if h in used_scope:
if get_tf_version_number() >= 1.5:
if get_tf_version_tuple() >= (1, 5):
with tf.variable_scope(scope, reuse=True, auxiliary_name_scope=False):
return func(*args, **kwargs)
else:
......
......@@ -5,7 +5,7 @@
import tensorflow as tf
from contextlib import contextmanager
from .common import get_tf_version_number
from .common import get_tf_version_tuple
__all__ = ['freeze_variables', 'remap_variables']
......@@ -13,7 +13,7 @@ __all__ = ['freeze_variables', 'remap_variables']
@contextmanager
def custom_getter_scope(custom_getter):
scope = tf.get_variable_scope()
if get_tf_version_number() >= 1.5:
if get_tf_version_tuple() >= (1, 5):
with tf.variable_scope(
scope, custom_getter=custom_getter,
auxiliary_name_scope=False):
......
from case_script import TestPythonScript
from tensorpack.tfutils.common import get_tf_version_number
from tensorpack.tfutils.common import get_tf_version_tuple
class InfoGANTest(TestPythonScript):
......@@ -10,6 +10,6 @@ class InfoGANTest(TestPythonScript):
return '../examples/GAN/InfoGAN-mnist.py'
def test(self):
if get_tf_version_number() < 1.4:
if get_tf_version_tuple() < (1, 4):
return True # requires leaky_relu
self.assertSurvive(self.script, args=None)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment