Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
b673b24c
Commit
b673b24c
authored
Jul 20, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
get_tf_version_number -> get_tf_version_tuple
parent
7eb08df1
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
45 additions
and
56 deletions
+45
-56
examples/DeepQNetwork/DQNModel.py
examples/DeepQNetwork/DQNModel.py
+0
-2
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+2
-2
examples/GAN/Improved-WGAN.py
examples/GAN/Improved-WGAN.py
+2
-2
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+6
-14
tensorpack/graph_builder/utils.py
tensorpack/graph_builder/utils.py
+2
-2
tensorpack/models/_old_batch_norm.py
tensorpack/models/_old_batch_norm.py
+2
-2
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+5
-5
tensorpack/models/conv2d.py
tensorpack/models/conv2d.py
+3
-3
tensorpack/models/tflayer.py
tensorpack/models/tflayer.py
+2
-2
tensorpack/predict/base.py
tensorpack/predict/base.py
+7
-13
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+8
-3
tensorpack/tfutils/scope_utils.py
tensorpack/tfutils/scope_utils.py
+2
-2
tensorpack/tfutils/varreplace.py
tensorpack/tfutils/varreplace.py
+2
-2
tests/test_infogan.py
tests/test_infogan.py
+2
-2
No files found.
examples/DeepQNetwork/DQNModel.py
View file @
b673b24c
...
...
@@ -4,13 +4,11 @@
import
abc
import
tensorflow
as
tf
import
tensorpack
from
tensorpack
import
ModelDesc
from
tensorpack.utils
import
logger
from
tensorpack.tfutils
import
(
varreplace
,
summary
,
get_current_tower_context
,
optimizer
,
gradproc
)
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
assert
tensorpack
.
tfutils
.
common
.
get_tf_version_number
()
>=
1.2
class
Model
(
ModelDesc
):
...
...
examples/FasterRCNN/train.py
View file @
b673b24c
...
...
@@ -22,7 +22,7 @@ assert six.PY3, "FasterRCNN requires Python 3!"
from
tensorpack
import
*
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils
import
optimizer
from
tensorpack.tfutils.common
import
get_tf_version_
number
from
tensorpack.tfutils.common
import
get_tf_version_
tuple
import
tensorpack.utils.viz
as
tpviz
from
coco
import
COCODetection
...
...
@@ -514,7 +514,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--config'
,
help
=
"A list of KEY=VALUE to overwrite those defined in config.py"
,
nargs
=
'+'
)
if
get_tf_version_
number
()
<
1.6
:
if
get_tf_version_
tuple
()
<
(
1
,
6
)
:
# https://github.com/tensorflow/tensorflow/issues/14657
logger
.
warn
(
"TF<1.6 has a bug which may lead to crash in FasterRCNN training if you're unlucky."
)
...
...
examples/GAN/Improved-WGAN.py
View file @
b673b24c
...
...
@@ -4,7 +4,7 @@
# Author: Yuxin Wu
from
tensorpack
import
*
from
tensorpack.tfutils
import
get_tf_version_
number
from
tensorpack.tfutils
import
get_tf_version_
tuple
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
import
tensorflow
as
tf
...
...
@@ -83,7 +83,7 @@ class Model(DCGAN.Model):
if
__name__
==
'__main__'
:
assert
get_tf_version_
number
()
>=
1.4
assert
get_tf_version_
tuple
()
>=
(
1
,
4
)
args
=
DCGAN
.
get_args
(
default_batch
=
64
,
default_z_dim
=
128
)
M
=
Model
(
shape
=
args
.
final_size
,
batch
=
args
.
batch
,
z_dim
=
args
.
z_dim
)
if
args
.
sample
:
...
...
tensorpack/callbacks/saver.py
View file @
b673b24c
...
...
@@ -8,7 +8,6 @@ import os
from
.base
import
Callback
from
..utils
import
logger
from
..tfutils.common
import
get_tf_version_number
__all__
=
[
'ModelSaver'
,
'MinSaver'
,
'MaxSaver'
]
...
...
@@ -51,13 +50,6 @@ class ModelSaver(Callback):
vars
.
extend
(
tf
.
get_collection
(
key
))
vars
=
list
(
set
(
vars
))
self
.
path
=
os
.
path
.
join
(
self
.
checkpoint_dir
,
'model'
)
if
get_tf_version_number
()
<=
1.1
:
self
.
saver
=
tf
.
train
.
Saver
(
var_list
=
vars
,
max_to_keep
=
self
.
_max_to_keep
,
keep_checkpoint_every_n_hours
=
self
.
_keep_every_n_hours
,
write_version
=
tf
.
train
.
SaverDef
.
V2
)
else
:
self
.
saver
=
tf
.
train
.
Saver
(
var_list
=
vars
,
max_to_keep
=
self
.
_max_to_keep
,
...
...
tensorpack/graph_builder/utils.py
View file @
b673b24c
...
...
@@ -8,7 +8,7 @@ import tensorflow as tf
from
..tfutils.varreplace
import
custom_getter_scope
from
..tfutils.scope_utils
import
under_name_scope
,
cached_name_scope
from
..tfutils.common
import
get_tf_version_
number
from
..tfutils.common
import
get_tf_version_
tuple
from
..utils.argtools
import
call_only_once
from
..utils
import
logger
...
...
@@ -67,7 +67,7 @@ class LeastLoadedDeviceSetter(object):
self
.
ps_sizes
=
[
0
]
*
len
(
self
.
ps_devices
)
def
__call__
(
self
,
op
):
if
get_tf_version_
number
()
>=
1.8
:
if
get_tf_version_
tuple
()
>=
(
1
,
8
)
:
from
tensorflow.python.training.device_util
import
canonicalize
else
:
def
canonicalize
(
name
):
# tensorflow/tensorflow#11484
...
...
tensorpack/models/_old_batch_norm.py
View file @
b673b24c
...
...
@@ -7,7 +7,7 @@ from tensorflow.python.training import moving_averages
from
..utils
import
logger
from
..utils.argtools
import
get_data_format
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.common
import
get_tf_version_
number
from
..tfutils.common
import
get_tf_version_
tuple
from
.common
import
layer_register
,
VariableHolder
from
.tflayer
import
convert_to_tflayer_args
...
...
@@ -128,7 +128,7 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
xn
=
tf
.
squeeze
(
xn
,
[
1
,
2
])
else
:
if
ctx
.
is_training
:
assert
get_tf_version_
number
()
>=
1.4
,
\
assert
get_tf_version_
tuple
()
>=
(
1
,
4
)
,
\
"Fine tuning a BatchNorm model with fixed statistics is only "
\
"supported after https://github.com/tensorflow/tensorflow/pull/12580 "
if
ctx
.
is_main_training_tower
:
# only warn in first tower
...
...
tensorpack/models/batch_norm.py
View file @
b673b24c
...
...
@@ -11,7 +11,7 @@ import six
from
..utils
import
logger
from
..utils.argtools
import
get_data_format
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.common
import
get_tf_version_
number
from
..tfutils.common
import
get_tf_version_
tuple
from
..tfutils.collection
import
backup_collection
,
restore_collection
from
.common
import
layer_register
,
VariableHolder
from
.tflayer
import
convert_to_tflayer_args
,
rename_get_variable
...
...
@@ -155,9 +155,9 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
if
training
is
None
:
training
=
ctx
.
is_training
training
=
bool
(
training
)
TF_version
=
get_tf_version_
number
()
TF_version
=
get_tf_version_
tuple
()
if
not
training
and
ctx
.
is_training
:
assert
TF_version
>=
1.4
,
\
assert
TF_version
>=
(
1
,
4
)
,
\
"Fine tuning a BatchNorm model with fixed statistics is only "
\
"supported after https://github.com/tensorflow/tensorflow/pull/12580 "
if
ctx
.
is_main_training_tower
:
# only warn in first tower
...
...
@@ -178,7 +178,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
gamma_initializer
=
gamma_initializer
,
fused
=
(
ndims
==
4
and
axis
in
[
1
,
3
]),
_reuse
=
tf
.
get_variable_scope
()
.
reuse
)
if
TF_version
>=
1.5
:
if
TF_version
>=
(
1
,
5
)
:
tf_args
[
'virtual_batch_size'
]
=
virtual_batch_size
else
:
assert
virtual_batch_size
is
None
,
"Feature not supported in this version of TF!"
...
...
@@ -220,7 +220,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
batch_mean_square
=
tf
.
reduce_mean
(
tf
.
square
(
inputs
),
axis
=
red_axis
)
if
sync_statistics
==
'nccl'
:
if
six
.
PY3
and
TF_version
<=
1.9
and
ctx
.
is_main_training_tower
:
if
six
.
PY3
and
TF_version
<=
(
1
,
9
)
and
ctx
.
is_main_training_tower
:
logger
.
warn
(
"A TensorFlow bug will cause cross-GPU BatchNorm to fail. "
"Apply this patch: https://github.com/tensorflow/tensorflow/pull/20360"
)
...
...
tensorpack/models/conv2d.py
View file @
b673b24c
...
...
@@ -4,7 +4,7 @@
import
tensorflow
as
tf
from
.common
import
layer_register
,
VariableHolder
from
..tfutils.common
import
get_tf_version_
number
from
..tfutils.common
import
get_tf_version_
tuple
from
..utils.argtools
import
shape2d
,
shape4d
,
get_data_format
from
.tflayer
import
rename_get_variable
,
convert_to_tflayer_args
...
...
@@ -86,14 +86,14 @@ def Conv2D(
out_channel
=
filters
assert
out_channel
%
split
==
0
assert
dilation_rate
==
(
1
,
1
)
or
get_tf_version_
number
()
>=
1.5
,
'TF>=1.5 required for group dilated conv'
assert
dilation_rate
==
(
1
,
1
)
or
get_tf_version_
tuple
()
>=
(
1
,
5
)
,
'TF>=1.5 required for group dilated conv'
kernel_shape
=
shape2d
(
kernel_size
)
filter_shape
=
kernel_shape
+
[
in_channel
/
split
,
out_channel
]
stride
=
shape4d
(
strides
,
data_format
=
data_format
)
kwargs
=
dict
(
data_format
=
data_format
)
if
get_tf_version_
number
()
>=
1.5
:
if
get_tf_version_
tuple
()
>=
(
1
,
5
)
:
kwargs
[
'dilations'
]
=
shape4d
(
dilation_rate
,
data_format
=
data_format
)
W
=
tf
.
get_variable
(
...
...
tensorpack/models/tflayer.py
View file @
b673b24c
...
...
@@ -6,7 +6,7 @@ import six
import
functools
from
..utils.argtools
import
get_data_format
from
..tfutils.common
import
get_tf_version_
number
from
..tfutils.common
import
get_tf_version_
tuple
from
..tfutils.varreplace
import
custom_getter_scope
...
...
@@ -112,7 +112,7 @@ def rename_tflayer_get_variable():
def
monkeypatch_tf_layers
():
if
get_tf_version_
number
()
<
1.4
:
if
get_tf_version_
tuple
()
<
(
1
,
4
)
:
if
not
hasattr
(
tf
.
layers
,
'Dense'
):
from
tensorflow.python.layers.core
import
Dense
tf
.
layers
.
Dense
=
Dense
...
...
tensorpack/predict/base.py
View file @
b673b24c
...
...
@@ -6,11 +6,10 @@ from abc import abstractmethod, ABCMeta
import
tensorflow
as
tf
import
six
from
..tfutils.common
import
get_tensors_by_names
,
get_tf_version_number
from
..tfutils.common
import
get_tensors_by_names
from
..tfutils.tower
import
PredictTowerContext
from
..input_source
import
PlaceholderInput
from
..utils.develop
import
log_deprecated
from
..utils.argtools
import
log_once
from
..utils.utils
import
execute_only_once
__all__
=
[
'PredictorBase'
,
'AsyncPredictorBase'
,
...
...
@@ -110,9 +109,7 @@ class OnlinePredictor(PredictorBase):
self
.
input_tensors
=
input_tensors
self
.
output_tensors
=
output_tensors
self
.
sess
=
sess
self
.
_use_callable
=
get_tf_version_number
()
>=
1.2
if
self
.
_use_callable
:
if
sess
is
not
None
:
self
.
_callable
=
sess
.
make_callable
(
fetches
=
output_tensors
,
...
...
@@ -120,9 +117,6 @@ class OnlinePredictor(PredictorBase):
accept_options
=
self
.
ACCEPT_OPTIONS
)
else
:
self
.
_callable
=
None
else
:
log_once
(
"TF>=1.2 is recommended for better performance of predictor!"
,
'warn'
)
def
_do_call_old
(
self
,
dp
):
feed
=
dict
(
zip
(
self
.
input_tensors
,
dp
))
...
...
tensorpack/tfutils/common.py
View file @
b673b24c
...
...
@@ -5,6 +5,7 @@
import
tensorflow
as
tf
from
six.moves
import
map
from
..utils.argtools
import
graph_memoized
from
..utils.develop
import
deprecated
__all__
=
[
'get_default_sess_config'
,
'get_global_step_value'
,
...
...
@@ -12,7 +13,6 @@ __all__ = ['get_default_sess_config',
# 'get_op_tensor_name',
# 'get_tensors_by_names',
# 'get_op_or_tensor_by_name',
# 'get_tf_version_number',
]
...
...
@@ -132,8 +132,13 @@ def get_op_or_tensor_by_name(name):
return
list
(
map
(
f
,
name
))
@
deprecated
(
"You should use get_tf_version_tuple instead due to the existence of TF 1.10"
)
def
get_tf_version_number
():
return
float
(
'.'
.
join
(
tf
.
VERSION
.
split
(
'.'
)[:
2
]))
def
get_tf_version_tuple
():
"""
Return
a float (for comparison), indicating tensorflow version
.
Return
TensorFlow version as a 2-element tuple (for comparison)
.
"""
return
float
(
'.'
.
join
(
tf
.
VERSION
.
split
(
'.'
)[:
2
]))
return
tuple
(
map
(
int
,
tf
.
VERSION
.
split
(
'.'
)[:
2
]))
tensorpack/tfutils/scope_utils.py
View file @
b673b24c
...
...
@@ -7,7 +7,7 @@ import functools
from
contextlib
import
contextmanager
from
..utils.argtools
import
graph_memoized
from
.common
import
get_tf_version_
number
from
.common
import
get_tf_version_
tuple
__all__
=
[
'auto_reuse_variable_scope'
,
'cached_name_scope'
,
'under_name_scope'
]
...
...
@@ -39,7 +39,7 @@ def auto_reuse_variable_scope(func):
h
=
hash
((
tf
.
get_default_graph
(),
scope
.
name
))
# print("Entering " + scope.name + " reuse: " + str(h in used_scope))
if
h
in
used_scope
:
if
get_tf_version_
number
()
>=
1.5
:
if
get_tf_version_
tuple
()
>=
(
1
,
5
)
:
with
tf
.
variable_scope
(
scope
,
reuse
=
True
,
auxiliary_name_scope
=
False
):
return
func
(
*
args
,
**
kwargs
)
else
:
...
...
tensorpack/tfutils/varreplace.py
View file @
b673b24c
...
...
@@ -5,7 +5,7 @@
import
tensorflow
as
tf
from
contextlib
import
contextmanager
from
.common
import
get_tf_version_
number
from
.common
import
get_tf_version_
tuple
__all__
=
[
'freeze_variables'
,
'remap_variables'
]
...
...
@@ -13,7 +13,7 @@ __all__ = ['freeze_variables', 'remap_variables']
@
contextmanager
def
custom_getter_scope
(
custom_getter
):
scope
=
tf
.
get_variable_scope
()
if
get_tf_version_
number
()
>=
1.5
:
if
get_tf_version_
tuple
()
>=
(
1
,
5
)
:
with
tf
.
variable_scope
(
scope
,
custom_getter
=
custom_getter
,
auxiliary_name_scope
=
False
):
...
...
tests/test_infogan.py
View file @
b673b24c
from
case_script
import
TestPythonScript
from
tensorpack.tfutils.common
import
get_tf_version_
number
from
tensorpack.tfutils.common
import
get_tf_version_
tuple
class
InfoGANTest
(
TestPythonScript
):
...
...
@@ -10,6 +10,6 @@ class InfoGANTest(TestPythonScript):
return
'../examples/GAN/InfoGAN-mnist.py'
def
test
(
self
):
if
get_tf_version_
number
()
<
1.4
:
if
get_tf_version_
tuple
()
<
(
1
,
4
)
:
return
True
# requires leaky_relu
self
.
assertSurvive
(
self
.
script
,
args
=
None
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment