seminar-breakout · Commit 505e28eb
Authored Mar 17, 2019 by Yuxin Wu
Parent: e4941595

Commit message: Backport TensorSpec; tf=tf.compat.v1 in many files.

Showing 39 changed files with 257 additions and 104 deletions.
Changed files:

    docs/tutorial/faq.md                       +9   -3
    examples/basics/mnist-convnet.py           +3   -3
    tensorpack/callbacks/base.py               +1   -1
    tensorpack/callbacks/graph.py              +1   -1
    tensorpack/callbacks/group.py              +1   -1
    tensorpack/callbacks/hooks.py              +1   -1
    tensorpack/callbacks/inference_runner.py   +2   -3
    tensorpack/callbacks/monitor.py            +1   -1
    tensorpack/callbacks/saver.py              +3   -3
    tensorpack/callbacks/steps.py              +1   -1
    tensorpack/callbacks/summary.py            +1   -1
    tensorpack/compat/__init__.py              +39  -0   (new file)
    tensorpack/compat/tensor_spec.py           +106 -0   (new file)
    tensorpack/graph_builder/model_desc.py     +5   -6
    tensorpack/input_source/input_source.py    +6   -5
    tensorpack/libinfo.py                      +1   -1
    tensorpack/models/batch_norm.py            +1   -1
    tensorpack/models/conv2d.py                +1   -1
    tensorpack/models/fc.py                    +1   -1
    tensorpack/models/layer_norm.py            +1   -1
    tensorpack/models/pool.py                  +1   -1
    tensorpack/models/registry.py              +2   -1
    tensorpack/models/regularize.py            +6   -5
    tensorpack/tfutils/collection.py           +2   -1
    tensorpack/tfutils/common.py               +11  -24
    tensorpack/tfutils/dependency.py           +1   -1
    tensorpack/tfutils/model_utils.py          +1   -1
    tensorpack/tfutils/optimizer.py            +2   -1
    tensorpack/tfutils/scope_utils.py          +1   -1
    tensorpack/tfutils/sesscreate.py           +14  -13
    tensorpack/tfutils/summary.py              +1   -1
    tensorpack/tfutils/symbolic_functions.py   +2   -1
    tensorpack/tfutils/tower.py                +2   -1
    tensorpack/tfutils/varreplace.py           +1   -1
    tensorpack/train/base.py                   +3   -2
    tensorpack/train/interface.py              +6   -3
    tensorpack/train/tower.py                  +13  -8
    tensorpack/utils/argtools.py               +2   -2
    tox.ini                                    +1   -1
docs/tutorial/faq.md

```diff
@@ -16,9 +16,15 @@ If you think:
 Then it is a good time to open an issue.

-## How to print/dump intermediate results in training
+## How to print/dump intermediate results during training

-1. Learn `tf.Print`.
+1. Learn `tf.Print`. Most of the times, adding one line in between:
+   ```python
+   tensor = obtain_a_tensor()
+   tensor = tf.Print(tensor, [tf.shape(tensor), tensor], tensor.name, summarize=100)
+   use_the_tensor(tensor)
+   ```
+   is sufficient.

 2. Know [DumpTensors](../modules/callbacks.html#tensorpack.callbacks.DumpTensors),
    [ProcessTensors](../modules/callbacks.html#tensorpack.callbacks.ProcessTensors) callbacks.
```
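The `tf.Print` snippet in the new FAQ text is a fragment; a self-contained version (a minimal sketch assuming TF 1.x, where `tf.Print` still exists, with stand-ins for the hypothetical `obtain_a_tensor`/`use_the_tensor` helpers) would be:

```python
import tensorflow as tf  # assumes TF 1.x

x = tf.random_normal([4, 3], name='activations')  # stand-in for obtain_a_tensor()
# tf.Print is an identity op: it returns `x` unchanged and, as a side effect,
# logs the shape and up to `summarize` values to stderr whenever it is evaluated
x = tf.Print(x, [tf.shape(x), x], x.name, summarize=100)
y = tf.reduce_sum(x)  # stand-in for use_the_tensor()

with tf.Session() as sess:
    sess.run(y)  # the printing happens here, at run time, not at graph-build time
```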
examples/basics/mnist-convnet.py

```diff
@@ -21,8 +21,8 @@ class Model(ModelDesc):
         """
         Define all the inputs (with type, shape, name) that the graph will need.
         """
-        return [tf.placeholder(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
-                tf.placeholder(tf.int32, (None,), 'label')]
+        return [tf.TensorSpec((None, IMAGE_SIZE, IMAGE_SIZE), tf.float32, 'input'),
+                tf.TensorSpec((None,), tf.int32, 'label')]

     def build_graph(self, image, label):
         """This function should build the model which takes the input variables
@@ -51,7 +51,7 @@ class Model(ModelDesc):
         cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')  # the average cross-entropy loss

-        correct = tf.cast(tf.nn.in_top_k(logits, label, 1), tf.float32, name='correct')
+        correct = tf.cast(tf.nn.in_top_k(predictions=logits, targets=label, k=1), tf.float32, name='correct')
         accuracy = tf.reduce_mean(correct, name='accuracy')

         # This will monitor training error & accuracy (in a moving average fashion). The value will be automatically
```
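The first hunk is the user-facing side of this commit: `inputs()` now returns `tf.TensorSpec` metadata instead of creating `tf.placeholder` ops, so describing an input no longer touches the graph. A small sketch of the equivalence (assuming `tf.TensorSpec` is available via TF >= 1.7 or the backport added below, and `IMAGE_SIZE = 28` as in this example); note the argument order differs between the two calls:

```python
import tensorflow as tf

IMAGE_SIZE = 28
# TensorSpec takes (shape, dtype, name)
spec = tf.TensorSpec((None, IMAGE_SIZE, IMAGE_SIZE), tf.float32, 'input')

# the spec is inert metadata; the framework can materialize a placeholder
# from it later, roughly like this (placeholder takes (dtype, shape, name)):
ph = tf.placeholder(spec.dtype, shape=spec.shape, name=spec.name)
```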
tensorpack/callbacks/base.py

```diff
@@ -4,7 +4,7 @@
 from abc import ABCMeta

 import six
-import tensorflow as tf
+from ..compat import tfv1 as tf

 from ..tfutils.common import get_op_or_tensor_by_name
```
tensorpack/callbacks/graph.py

```diff
@@ -6,9 +6,9 @@
 import numpy as np
 import os
-import tensorflow as tf
 from six.moves import zip

+from ..compat import tfv1 as tf
 from ..tfutils.common import get_op_tensor_name
 from ..utils import logger
 from .base import Callback
```
tensorpack/callbacks/group.py

```diff
@@ -6,7 +6,7 @@ import traceback
 from contextlib import contextmanager
 from time import time as timer
 import six
-import tensorflow as tf
+from ..compat import tfv1 as tf

 from ..utils import logger
 from ..utils.utils import humanize_time_delta
```
tensorpack/callbacks/hooks.py

```diff
@@ -6,7 +6,7 @@
 import tensorflow as tf

-from ..tfutils.common import tfv1
+from ..compat import tfv1
 from ..utils.develop import HIDE_DOC
 from .base import Callback
```
tensorpack/callbacks/inference_runner.py

```diff
@@ -5,15 +5,14 @@
 import itertools
 import sys
 from contextlib import contextmanager
-import tensorflow as tf
 import tqdm
 from six.moves import range
 from tensorflow.python.training.monitored_session import _HookedSession as HookedSession

+from ..compat import tfv1 as tf
 from ..dataflow.base import DataFlow
 from ..input_source import FeedInput, InputSource, QueueInput, StagingInput
 from ..tfutils.tower import PredictTowerContext
-from ..tfutils.common import tfv1
 from ..utils import logger
 from ..utils.utils import get_tqdm_kwargs
 from .base import Callback
@@ -28,7 +27,7 @@ def _device_from_int(dev):
     return '/gpu:{}'.format(dev) if dev >= 0 else '/cpu:0'


-class InferencerToHook(tfv1.train.SessionRunHook):
+class InferencerToHook(tf.train.SessionRunHook):
     def __init__(self, inf, fetches):
         self._inf = inf
         self._fetches = fetches
```
tensorpack/callbacks/monitor.py

```diff
@@ -12,8 +12,8 @@ import time
 from collections import defaultdict
 from datetime import datetime
 import six
-import tensorflow as tf

+from ..compat import tfv1 as tf
 from ..libinfo import __git_version__
 from ..tfutils.summary import create_image_summary, create_scalar_summary
 from ..utils import logger
```
tensorpack/callbacks/saver.py

```diff
@@ -4,8 +4,8 @@
 import os
 from datetime import datetime
-import tensorflow as tf

+from ..compat import tfv1 as tf
 from ..utils import logger
 from .base import Callback
@@ -40,8 +40,8 @@ class ModelSaver(Callback):
         if checkpoint_dir is None:
             checkpoint_dir = logger.get_logger_dir()
         if checkpoint_dir is not None:
-            if not tf.gfile.IsDirectory(checkpoint_dir):
-                tf.gfile.MakeDirs(checkpoint_dir)
+            if not tf.gfile.IsDirectory(checkpoint_dir):  # v2: tf.io.gfile.isdir
+                tf.gfile.MakeDirs(checkpoint_dir)  # v2: tf.io.gfile.makedirs
         self.checkpoint_dir = checkpoint_dir

     def _setup_graph(self):
```
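The new `# v2: ...` comments record the TF2 names for these filesystem calls. A version-agnostic sketch (a hypothetical helper, not part of this commit; both API families exist in recent TF releases):

```python
import tensorflow as tf

def ensure_dir(path):
    # prefer tf.io.gfile (TF2 / late-TF1 naming), fall back to tf.gfile (TF1)
    gfile = tf.io.gfile if hasattr(tf, 'io') and hasattr(tf.io, 'gfile') else tf.gfile
    isdir = getattr(gfile, 'isdir', None) or gfile.IsDirectory
    makedirs = getattr(gfile, 'makedirs', None) or gfile.MakeDirs
    if not isdir(path):
        makedirs(path)
```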
tensorpack/callbacks/steps.py

```diff
@@ -3,10 +3,10 @@
 """ Some common step callbacks. """

-import tensorflow as tf
 import tqdm
 from six.moves import zip

+from ..compat import tfv1 as tf
 from ..tfutils.common import get_global_step_var, get_op_tensor_name
 from ..utils import logger
 from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
```
tensorpack/callbacks/summary.py

```diff
@@ -4,8 +4,8 @@
 import numpy as np
 from collections import deque
-import tensorflow as tf

+from ..compat import tfv1 as tf
 from ..tfutils.common import get_op_tensor_name
 from ..utils import logger
 from ..utils.naming import MOVING_SUMMARY_OPS_KEY
```
tensorpack/compat/__init__.py  (new file)

```python
#!/usr/bin/env python

import tensorflow as tf


def backport_tensor_spec():
    if hasattr(tf, 'TensorSpec'):
        return tf.TensorSpec
    try:
        # available since 1.7
        from tensorflow.python.framework.tensor_spec import TensorSpec
    except ImportError:
        pass
    else:
        tf.TensorSpec = TensorSpec
        return TensorSpec

    from .tensor_spec import TensorSpec
    tf.TensorSpec = TensorSpec
    return TensorSpec


def is_tfv2():
    try:
        from tensorflow.python import tf2
        return tf2.enabled()
    except Exception:
        return False


if is_tfv2():
    tfv1 = tf.compat.v1
    if not hasattr(tf, 'layers'):
        # promised at https://github.com/tensorflow/community/pull/24#issuecomment-440453886
        tf.layers = tf.keras.layers
else:
    tfv1 = tf
```
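This module is the pivot of the whole commit: it computes `tfv1` once (the `tf.compat.v1` namespace under TF2, plain `tf` under TF1), and every other file simply aliases it. The recurring one-line change in the diffs below then looks like this (a sketch of the pattern, not a new API):

```python
# before: the module talks to whatever API the installed TF exposes
# import tensorflow as tf

# after: the module is pinned to the v1 API on both TF1 and TF2
from ..compat import tfv1 as tf

# all existing v1-style calls in the module keep working unchanged, e.g.
# tf.placeholder(...), tf.variable_scope(...), tf.train.SessionRunHook, ...
```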
tensorpack/compat/tensor_spec.py  (new file)

```python
"""
Copied from tensorflow/python/framework/tensor_spec.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape


class TensorSpec(object):
    """Describes a tf.Tensor.

    Metadata for describing the `tf.Tensor` objects accepted or returned
    by some TensorFlow APIs.
    """

    __slots__ = ["_shape", "_shape_tuple", "_dtype", "_name"]

    def __init__(self, shape, dtype=dtypes.float32, name=None):
        """Creates a TensorSpec.

        Args:
          shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
          dtype: Value convertible to `tf.DType`. The type of the tensor values.
          name: Optional name for the Tensor.

        Raises:
          TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is
            not convertible to a `tf.DType`.
        """
        self._shape = tensor_shape.TensorShape(shape)
        try:
            self._shape_tuple = tuple(self.shape.as_list())
        except ValueError:
            self._shape_tuple = None
        self._dtype = dtypes.as_dtype(dtype)
        self._name = name

    @classmethod
    def from_spec(cls, spec, name=None):
        return cls(spec.shape, spec.dtype, name or spec.name)

    @classmethod
    def from_tensor(cls, tensor, name=None):
        if isinstance(tensor, ops.EagerTensor):
            return TensorSpec(tensor.shape, tensor.dtype, name)
        elif isinstance(tensor, ops.Tensor):
            return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name)
        else:
            raise ValueError("`tensor` should be a tf.Tensor")

    @property
    def shape(self):
        """Returns the `TensorShape` that represents the shape of the tensor."""
        return self._shape

    @property
    def dtype(self):
        """Returns the `dtype` of elements in the tensor."""
        return self._dtype

    @property
    def name(self):
        """Returns the (optionally provided) name of the described tensor."""
        return self._name

    def is_compatible_with(self, spec_or_tensor):
        """Returns True if spec_or_tensor is compatible with this TensorSpec.

        Two tensors are considered compatible if they have the same dtype
        and their shapes are compatible (see `tf.TensorShape.is_compatible_with`).

        Args:
          spec_or_tensor: A tf.TensorSpec or a tf.Tensor

        Returns:
          True if spec_or_tensor is compatible with self.
        """
        return (self._dtype.is_compatible_with(spec_or_tensor.dtype) and
                self._shape.is_compatible_with(spec_or_tensor.shape))

    def __repr__(self):
        return "TensorSpec(shape={}, dtype={}, name={})".format(
            self.shape, repr(self.dtype), repr(self.name))

    def __hash__(self):
        return hash((self._shape_tuple, self.dtype))

    def __eq__(self, other):
        return (self._shape_tuple == other._shape_tuple  # pylint: disable=protected-access
                and self.dtype == other.dtype
                and self._name == other._name)  # pylint: disable=protected-access

    def __ne__(self, other):
        return not self == other

    def __reduce__(self):
        return TensorSpec, (self._shape, self._dtype, self._name)
```
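A quick usage sketch of the backported class (assuming TF 1.x graph mode and that tensorpack is importable as installed):

```python
import tensorflow as tf
from tensorpack.compat.tensor_spec import TensorSpec

spec = TensorSpec((None, 28, 28), tf.float32, 'input')
print(spec)  # TensorSpec(shape=(?, 28, 28), dtype=tf.float32, name='input')

ph = tf.placeholder(tf.float32, (None, 28, 28), 'input')
print(spec.is_compatible_with(ph))          # True: same dtype, compatible shapes
print(spec == TensorSpec.from_tensor(ph))   # also True: name is taken from the op
```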
tensorpack/graph_builder/model_desc.py

```diff
@@ -7,13 +7,12 @@ import tensorflow as tf
 from ..models.regularize import regularize_cost_from_collection
 from ..tfutils.tower import get_current_tower_context
-from ..tfutils.common import get_tf_version_tuple
 from ..utils import logger
 from ..utils.argtools import memoized_method
 from ..utils.develop import log_deprecated
+from ..compat import backport_tensor_spec, tfv1

-if get_tf_version_tuple() >= (1, 7):
-    from tensorflow.python.framework.tensor_spec import TensorSpec
+TensorSpec = backport_tensor_spec()

 __all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']
@@ -49,8 +48,8 @@ class InputDesc(
         Returns:
             tf.Tensor:
         """
-        with tf.name_scope(None):   # clear any name scope it might get called in
-            ret = tf.placeholder(
+        with tfv1.name_scope(None):   # clear any name scope it might get called in
+            ret = tfv1.placeholder(
                 self.type, shape=self.shape, name=self.name)
         self._register_cached_placeholder(ret)
         return ret
@@ -63,7 +62,7 @@ class InputDesc(
         Returns:
             tf.Tensor:
         """
-        g = tf.get_default_graph()
+        g = tfv1.get_default_graph()
         if g in self._cached_placeholder:
             return self._cached_placeholder[g]
         else:
```
tensorpack/input_source/input_source.py

```diff
@@ -8,6 +8,7 @@ from itertools import chain
 import tensorflow as tf
 from six.moves import range, zip

+from ..compat import tfv1
 from ..callbacks.base import Callback, CallbackFactory
 from ..callbacks.graph import RunOp
 from ..dataflow import DataFlow, MapData, RepeatedData
@@ -84,7 +85,7 @@ class FeedInput(InputSource):
             dp = next(self._itr)
             assert len(dp) == len(self._placeholders), "[FeedInput] datapoints and inputs are of different length!"
             feed = _make_feeds(self._placeholders, dp)
-            return tf.train.SessionRunArgs(fetches=[], feed_dict=feed)
+            return tfv1.train.SessionRunArgs(fetches=[], feed_dict=feed)

         def _reset(self):
             self._itr = self._ds.__iter__()
@@ -228,9 +229,9 @@ class QueueInput(FeedfreeInput):
         """
         self.thread.pause()     # pause enqueue

-        opt = tf.RunOptions()
+        opt = tfv1.RunOptions()
         opt.timeout_in_ms = 2000   # 2s
-        sess = tf.get_default_session()
+        sess = tfv1.get_default_session()
         # dequeue until empty
         try:
             while True:
@@ -304,7 +305,7 @@ class BatchQueueInput(QueueInput):
         # prepare placeholders without the first dimension
         placehdrs_nobatch = []
         for p in self.input_placehdrs:
-            placehdrs_nobatch.append(tf.placeholder(
+            placehdrs_nobatch.append(tfv1.placeholder(
                 dtype=p.dtype, shape=p.get_shape().as_list()[1:],
                 name=get_op_tensor_name(p.name)[0] + '-nobatch'))
@@ -546,7 +547,7 @@ class StagingInput(FeedfreeInput):
             unstage_ops = self._input._get_unstage_ops()
             unstage_op = tf.group(*unstage_ops, name='unstage_all')
             self._check_dependency_op = unstage_ops[0]
-            self.fetches = tf.train.SessionRunArgs(
+            self.fetches = tfv1.train.SessionRunArgs(
                 fetches=[self.stage_op, unstage_op])

     def _prefill(self, sess):
```
tensorpack/libinfo.py

```diff
@@ -52,7 +52,7 @@ os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '0'
 try:
     import tensorflow as tf  # noqa
     _version = tf.__version__.split('.')
-    assert int(_version[0]) >= 1 and int(_version[1]) >= 3, "TF>=1.3 is required!"
+    assert (int(_version[0]), int(_version[1])) >= (1, 3), "TF>=1.3 is required!"
     _HAS_TF = True
 except ImportError:
     print("Failed to import tensorflow.")
```
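This one-line change is a real bug fix rather than cleanup: the old element-wise check rejects any release whose minor version is below 3, including TF 2.0, while tuple comparison is lexicographic. A worked example:

```python
# old check, evaluated for TF 2.0: fails, because int('0') >= 3 is False
int('2') >= 1 and int('0') >= 3   # -> False

# new check: tuples compare lexicographically, so 2.0 passes
(2, 0) >= (1, 3)    # -> True
(1, 13) >= (1, 3)   # -> True
(1, 2) >= (1, 3)    # -> False, correctly rejected
```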
tensorpack/models/batch_norm.py

```diff
@@ -4,7 +4,7 @@
 import re
 import six
-import tensorflow as tf
+from ..compat import tfv1 as tf  # this should be avoided first in model code
 from tensorflow.python.training import moving_averages

 from ..tfutils.collection import backup_collection, restore_collection
```
tensorpack/models/conv2d.py

```diff
@@ -2,7 +2,7 @@
 # File: conv2d.py

-import tensorflow as tf
+from ..compat import tfv1 as tf  # this should be avoided first in model code

 from ..tfutils.common import get_tf_version_tuple
 from ..utils.argtools import get_data_format, shape2d, shape4d, log_once
```
tensorpack/models/fc.py

```diff
@@ -3,7 +3,7 @@
 import numpy as np
-import tensorflow as tf
+from ..compat import tfv1 as tf  # this should be avoided first in model code

 from ..tfutils.common import get_tf_version_tuple
 from .common import VariableHolder, layer_register
```
tensorpack/models/layer_norm.py

```diff
@@ -2,7 +2,7 @@
 # File: layer_norm.py

-import tensorflow as tf
+from ..compat import tfv1 as tf  # this should be avoided first in model code

 from ..utils.argtools import get_data_format
 from .common import VariableHolder, layer_register
```
tensorpack/models/pool.py

```diff
@@ -2,7 +2,7 @@
 # File: pool.py

 import numpy as np
-import tensorflow as tf
+from ..compat import tfv1 as tf  # this should be avoided first in model code

 from ..utils.argtools import get_data_format, shape2d
 from ..utils.develop import log_deprecated
```
tensorpack/models/registry.py

```diff
@@ -8,6 +8,7 @@ from functools import wraps
 import six
 import tensorflow as tf

+from ..compat import tfv1
 from ..tfutils.argscope import get_arg_scope
 from ..tfutils.model_utils import get_shape_str
 from ..utils import logger
@@ -117,7 +118,7 @@ def layer_register(
             # del actual_args[k]

             if name is not None:        # use scope
-                with tf.variable_scope(name) as scope:
+                with tfv1.variable_scope(name) as scope:
                     # this name is only used to surpress logging, doesn't hurt to do some heuristics
                     scope_name = re.sub('tower[0-9]+/', '', scope.name)
                     do_log_shape = log_shape and scope_name not in _LAYER_LOGGED
```
tensorpack/models/regularize.py

```diff
@@ -5,6 +5,7 @@
 import re
 import tensorflow as tf

+from ..compat import tfv1
 from ..tfutils.common import get_tf_version_tuple
 from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
@@ -60,13 +61,13 @@ def regularize_cost(regex, func, name='regularize_cost'):
     # If vars are shared, regularize all of them
     # If vars are replicated, only regularize those in the current tower
     if ctx.has_own_variables:
-        params = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
+        params = ctx.get_collection_in_tower(tfv1.GraphKeys.TRAINABLE_VARIABLES)
     else:
-        params = tf.trainable_variables()
+        params = tfv1.trainable_variables()

     names = []
-    with tf.name_scope(name + '_internals'):
+    with tfv1.name_scope(name + '_internals'):
         costs = []
         for p in params:
             para_name = p.op.name
@@ -119,9 +120,9 @@ def regularize_cost_from_collection(name='regularize_cost'):
     # NOTE: this collection doesn't always grow with towers.
     # It only grows with actual variable creation, but not get_variable call.
     if ctx.has_own_variables:   # be careful of the first tower (name='')
-        losses = ctx.get_collection_in_tower(tf.GraphKeys.REGULARIZATION_LOSSES)
+        losses = ctx.get_collection_in_tower(tfv1.GraphKeys.REGULARIZATION_LOSSES)
     else:
-        losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+        losses = tfv1.get_collection(tfv1.GraphKeys.REGULARIZATION_LOSSES)
     if len(losses) > 0:
         logger.info("regularize_cost_from_collection() found {} regularizers "
                     "in REGULARIZATION_LOSSES collection.".format(len(losses)))
```
tensorpack/tfutils/collection.py

```diff
@@ -5,7 +5,8 @@
 from contextlib import contextmanager
 from copy import copy
 import six
-import tensorflow as tf
+
+from ..compat import tfv1 as tf

 from ..utils import logger
 from ..utils.argtools import memoized
```
tensorpack/tfutils/common.py

```diff
@@ -5,12 +5,13 @@
 import tensorflow as tf
 from six.moves import map

+from ..compat import tfv1
 from ..utils.argtools import graph_memoized

 __all__ = ['get_default_sess_config',
            'get_global_step_value',
            'get_global_step_var',
-           'get_tf_version_tuple'
+           'get_tf_version_tuple',
            # 'get_op_tensor_name',
            # 'get_tensors_by_names',
            # 'get_op_or_tensor_by_name',
@@ -30,7 +31,7 @@ def get_default_sess_config(mem_fraction=0.99):
     Returns:
         tf.ConfigProto: the config to use.
     """
-    conf = tf.ConfigProto()
+    conf = tfv1.ConfigProto()

     conf.allow_soft_placement = True
     # conf.log_device_placement = True
@@ -64,9 +65,9 @@ def get_global_step_var():
     Returns:
         tf.Tensor: the global_step variable in the current graph. Create if doesn't exist.
     """
-    scope = tf.VariableScope(reuse=False, name='')  # the root vs
-    with tf.variable_scope(scope):
-        var = tf.train.get_or_create_global_step()
+    scope = tfv1.VariableScope(reuse=False, name='')  # the root vs
+    with tfv1.variable_scope(scope):
+        var = tfv1.train.get_or_create_global_step()
     return var
@@ -78,8 +79,8 @@ def get_global_step_value():
     Has to be called under a default session.
     """
-    return tf.train.global_step(
-        tf.get_default_session(),
+    return tfv1.train.global_step(
+        tfv1.get_default_session(),
         get_global_step_var())
@@ -108,7 +109,7 @@ def get_tensors_by_names(names):
         names (list):
     """
     ret = []
-    G = tf.get_default_graph()
+    G = tfv1.get_default_graph()
     for n in names:
         opn, varn = get_op_tensor_name(n)
         ret.append(G.get_tensor_by_name(varn))
@@ -125,7 +126,7 @@ def get_op_or_tensor_by_name(name):
     Raises:
         KeyError, if the name doesn't exist
     """
-    G = tf.get_default_graph()
+    G = tfv1.get_default_graph()

     def f(n):
         if len(n) >= 3 and n[-2] == ':':
@@ -140,7 +141,7 @@
 def gpu_available_in_session():
-    sess = tf.get_default_session()
+    sess = tfv1.get_default_session()
     for dev in sess.list_devices():
         if dev.device_type.lower() == 'gpu':
             return True
@@ -152,17 +153,3 @@ def get_tf_version_tuple():
     Return TensorFlow version as a 2-element tuple (for comparison).
     """
     return tuple(map(int, tf.__version__.split('.')[:2]))
-
-
-def is_tf2():
-    try:
-        from tensorflow.python import tf2
-        return tf2.enabled()
-    except Exception:
-        return False
-
-
-if is_tf2():
-    tfv1 = tf.compat.v1
-else:
-    tfv1 = tf
```
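The last hunk moves `is_tf2()` and the `tfv1` binding out of this module and into `tensorpack.compat`, leaving only `get_tf_version_tuple()` behind. For reference, what that helper computes, shown on a few literal version strings:

```python
# get_tf_version_tuple() keeps only the first two numeric components
tuple(map(int, '1.13.1'.split('.')[:2]))           # -> (1, 13)
tuple(map(int, '2.0.0-alpha0'.split('.')[:2]))     # -> (2, 0)
tuple(map(int, '1.7.0'.split('.')[:2])) >= (1, 7)  # -> True
```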
tensorpack/tfutils/dependency.py

```diff
@@ ... @@
 import tensorflow as tf
-from tensorflow.contrib.graph_editor import get_backward_walk_ops

 from ..utils.argtools import graph_memoized
@@ -33,6 +32,7 @@ def dependency_of_targets(targets, op):
         op = op.op
     assert isinstance(op, tf.Operation), op

+    from tensorflow.contrib.graph_editor import get_backward_walk_ops
     # alternative implementation can use graph_util.extract_sub_graph
     dependent_ops = get_backward_walk_ops(targets, control_inputs=True)
     return op in dependent_ops
```
tensorpack/tfutils/model_utils.py

```diff
@@ -2,7 +2,7 @@
 # File: model_utils.py
 # Author: tensorpack contributors

-import tensorflow as tf
+from ..compat import tfv1 as tf

 from tabulate import tabulate
 from termcolor import colored
```
tensorpack/tfutils/optimizer.py

```diff
@@ -5,7 +5,8 @@
 from contextlib import contextmanager
 import tensorflow as tf

-from ..tfutils.common import get_tf_version_tuple, tfv1
+from ..tfutils.common import get_tf_version_tuple
+from ..compat import tfv1
 from ..utils.develop import HIDE_DOC
 from .gradproc import FilterNoneGrad, GradientProcessor
```
tensorpack/tfutils/scope_utils.py

```diff
@@ -4,8 +4,8 @@
 import functools
 from contextlib import contextmanager
-import tensorflow as tf

+from ..compat import tfv1 as tf
 from ..utils.argtools import graph_memoized
 from .common import get_tf_version_tuple
```
tensorpack/tfutils/sesscreate.py

```diff
@@ -2,10 +2,7 @@
 # File: sesscreate.py

-import tensorflow as tf
-from tensorflow.contrib.graph_editor import get_backward_walk_ops
-from ..tfutils.common import tfv1
+from ..compat import tfv1 as tf, is_tfv2
 from ..utils import logger
 from .common import get_default_sess_config
@@ -20,7 +17,7 @@ A SessionCreator should:
 """


-class NewSessionCreator(tfv1.train.SessionCreator):
+class NewSessionCreator(tf.train.SessionCreator):
     def __init__(self, target='', config=None):
         """
         Args:
@@ -59,12 +56,16 @@ bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
                 return False

         def run(op):
-            deps = get_backward_walk_ops(op, control_inputs=True)
-            for dep_op in deps:
-                if blocking_op(dep_op):
-                    logger.warn(
-                        "Initializer '{}' depends on a blocking op '{}'. This initializer is likely to hang!".format(
-                            op.name, dep_op.name))
+            if not is_tfv2():
+                from tensorflow.contrib.graph_editor import get_backward_walk_ops
+                deps = get_backward_walk_ops(op, control_inputs=True)
+                for dep_op in deps:
+                    if blocking_op(dep_op):
+                        logger.warn(
+                            "Initializer '{}' depends on a blocking op '{}'. "
+                            "This initializer is likely to hang!".format(op.name, dep_op.name))
             sess.run(op)

         run(tf.global_variables_initializer())
@@ -73,7 +74,7 @@ bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
         return sess


-class ReuseSessionCreator(tfv1.train.SessionCreator):
+class ReuseSessionCreator(tf.train.SessionCreator):
     """
     Returns an existing session.
     """
@@ -88,7 +89,7 @@ class ReuseSessionCreator(tfv1.train.SessionCreator):
         return self.sess


-class SessionCreatorAdapter(tfv1.train.SessionCreator):
+class SessionCreatorAdapter(tf.train.SessionCreator):
     """
     Apply a function on the output of a SessionCreator. Can be used to create a debug session.
     """
```
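The notable part of the middle hunk is the guarded, function-local import: `tensorflow.contrib` does not exist under TF2, so the `graph_editor` dependency can no longer be imported at module load time. A minimal sketch of the same pattern, with hypothetical helper names:

```python
from tensorpack.compat import is_tfv2
from tensorpack.utils import logger

def warn_blocking_deps(op, blocking_op):
    """Warn if `op` transitively depends on a blocking op. TF1-only check."""
    if is_tfv2():
        return  # tensorflow.contrib is gone in TF2, so skip the check entirely
    # import inside the function: merely loading this module never touches contrib
    from tensorflow.contrib.graph_editor import get_backward_walk_ops
    for dep_op in get_backward_walk_ops(op, control_inputs=True):
        if blocking_op(dep_op):
            logger.warn("Initializer '{}' depends on a blocking op '{}'. "
                        "This initializer is likely to hang!".format(op.name, dep_op.name))
```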
tensorpack/tfutils/summary.py

```diff
@@ -5,10 +5,10 @@
 import re
 from contextlib import contextmanager
 import six
-import tensorflow as tf
 from six.moves import range
 from tensorflow.python.training import moving_averages

+from ..compat import tfv1 as tf
 from ..utils import logger
 from ..utils.argtools import graph_memoized
 from ..utils.naming import MOVING_SUMMARY_OPS_KEY
```
tensorpack/tfutils/symbolic_functions.py

```diff
@@ -4,6 +4,7 @@
 import tensorflow as tf

+from ..compat import tfv1
 from ..utils.develop import deprecated

 __all__ = ['print_stat', 'rms']
@@ -30,7 +31,7 @@ def rms(x, name=None):
     """
     if name is None:
         name = x.op.name + '/rms'
-        with tf.name_scope(None):   # name already contains the scope
+        with tfv1.name_scope(None):   # name already contains the scope
             return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
     return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
```
tensorpack/tfutils/tower.py

```diff
@@ -4,9 +4,10 @@
 from abc import ABCMeta, abstractmethod, abstractproperty
 import six
-import tensorflow as tf
 from six.moves import zip

+from ..compat import tfv1 as tf
+
 from ..utils import logger
 from ..utils.argtools import call_only_once
 from ..utils.develop import HIDE_DOC
```
tensorpack/tfutils/varreplace.py

```diff
@@ -3,8 +3,8 @@
 # Credit: Qinyao He

 from contextlib import contextmanager
-import tensorflow as tf

+from ..compat import tfv1 as tf
 from .common import get_tf_version_tuple

 __all__ = ['custom_getter_scope', 'freeze_variables', 'remap_variables']
```
tensorpack/train/base.py

```diff
@@ -8,6 +8,7 @@ import six
 import tensorflow as tf
 from six.moves import range

+from ..compat import tfv1
 from ..callbacks import Callback, Callbacks, Monitors, MonitorBase
 from ..callbacks.steps import MaintainStepCounter
 from ..tfutils import get_global_step_value
@@ -222,7 +223,7 @@ class Trainer(object):
             session_creator (tf.train.SessionCreator):
             session_init (sessinit.SessionInit):
         """
-        assert isinstance(session_creator, tf.train.SessionCreator), session_creator
+        assert isinstance(session_creator, tfv1.train.SessionCreator), session_creator
         assert isinstance(session_init, SessionInit), session_init
         session_init._setup_graph()
@@ -250,7 +251,7 @@ class Trainer(object):
         which can be useful when the training is not done by a single `train_op`.
         """
         hooks = self._callbacks.get_hooks()
-        self.hooked_sess = tf.train.MonitoredSession(
+        self.hooked_sess = tfv1.train.MonitoredSession(
             session_creator=ReuseSessionCreator(self.sess), hooks=hooks)

     @call_only_once
```
tensorpack/train/interface.py

```diff
@@ ... @@
 # -*- coding: utf-8 -*-
 # File: interface.py

-import tensorflow as tf
+from ..compat import tfv1
 from ..input_source import (
     DummyConstantInput, FeedfreeInput, FeedInput, InputSource, QueueInput, StagingInput)
 from ..utils import logger
+from ..compat import is_tfv2
 from .config import TrainConfig
 from .tower import SingleCostTrainer
 from .trainers import SimpleTrainer
@@ -71,6 +71,9 @@ def launch_train_with_config(config, trainer):
         launch_train_with_config(
             config, SyncMultiGPUTrainerParameterServer(8, ps_device='gpu'))
     """
+    if is_tfv2():
+        tfv1.disable_eager_execution()
+
     assert isinstance(trainer, SingleCostTrainer), trainer
     assert isinstance(config, TrainConfig), config
     assert config.model is not None
@@ -99,7 +102,7 @@ def launch_train_with_config(config, trainer):
 def _check_unused_regularization():
-    coll = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+    coll = tfv1.get_collection(tfv1.GraphKeys.REGULARIZATION_LOSSES)
     unconsumed_reg = []
     for c in coll:
         if len(c.consumers()) == 0:
```
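The new `disable_eager_execution()` call is what lets the graph-based trainers run at all under TF2, where eager mode is on by default; it has to happen before any graph or session is created. A sketch of the equivalent call in user code:

```python
import tensorflow as tf

# under TF2, switch back to graph mode before building anything;
# tf.compat.v1.disable_eager_execution() is what tfv1.disable_eager_execution()
# resolves to in this commit
if tf.__version__.startswith('2'):
    tf.compat.v1.disable_eager_execution()

assert not tf.executing_eagerly()
```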
tensorpack/train/tower.py

```diff
@@ -5,6 +5,7 @@ from abc import ABCMeta, abstractmethod
 import six
 import tensorflow as tf

+from ..compat import tfv1, is_tfv2
 from ..input_source import PlaceholderInput
 from ..predict.base import OnlinePredictor
 from ..tfutils.gradproc import FilterNoneGrad
@@ -126,7 +127,7 @@ class TowerTrainer(Trainer):
         input.setup(self.inputs_desc)

         vs_name = self._vs_name_for_predictor(device_id)
-        with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
+        with tfv1.variable_scope(tfv1.get_variable_scope(), reuse=True), \
                 tf.device(device), PredictTowerContext(
                     tower_name, vs_name=vs_name):
             logger.info("Building graph for predict tower '{}' on device {} {}...".format(
@@ -254,15 +255,19 @@ class SingleCostTrainer(TowerTrainer):
                 return None     # this is the tower function, could be called for inference

             if ctx.has_own_variables:
-                varlist = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
+                varlist = ctx.get_collection_in_tower(tfv1.GraphKeys.TRAINABLE_VARIABLES)
             else:
-                varlist = tf.trainable_variables()
+                varlist = tfv1.trainable_variables()
             opt = get_opt_fn()
-            grads = opt.compute_gradients(
-                cost, var_list=varlist,
-                gate_gradients=self.GATE_GRADIENTS,
-                colocate_gradients_with_ops=self.COLOCATE_GRADIENTS_WITH_OPS,
-                aggregation_method=self.AGGREGATION_METHOD)
+            if is_tfv2() and isinstance(opt, tf.optimizers.Optimizer):
+                grads = opt.get_gradients(cost, varlist)
+                grads = list(zip(grads, varlist))
+            else:
+                grads = opt.compute_gradients(
+                    cost, var_list=varlist,
+                    gate_gradients=self.GATE_GRADIENTS,
+                    colocate_gradients_with_ops=self.COLOCATE_GRADIENTS_WITH_OPS,
+                    aggregation_method=self.AGGREGATION_METHOD)
             grads = FilterNoneGrad().process(grads)
             return grads
```
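The new branch accounts for the optimizer API split: v1 optimizers expose `compute_gradients`, while Keras/v2 optimizers only expose `get_gradients(loss, params)`, which returns bare gradients that must be zipped with the variables to form the usual `(grad, var)` pairs. A reduced sketch of the two paths (hypothetical helper, error handling omitted):

```python
def compute_grad_pairs(opt, cost, varlist):
    # reduced sketch of the hunk above
    if hasattr(opt, 'compute_gradients'):      # tf.train.* (v1) optimizers
        return opt.compute_gradients(cost, var_list=varlist)
    grads = opt.get_gradients(cost, varlist)   # tf.optimizers.* (Keras) API
    return list(zip(grads, varlist))
```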
tensorpack/utils/argtools.py

```diff
@@ -52,7 +52,7 @@ def graph_memoized(func):
     """
     # TODO it keeps the graph alive
-    import tensorflow as tf
+    from ..compat import tfv1

     GRAPH_ARG_NAME = '__IMPOSSIBLE_NAME_FOR_YOU__'

     @memoized
@@ -63,7 +63,7 @@ def graph_memoized(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         assert GRAPH_ARG_NAME not in kwargs, "No Way!!"
-        graph = tf.get_default_graph()
+        graph = tfv1.get_default_graph()
         kwargs[GRAPH_ARG_NAME] = graph
         return func_with_graph_arg(*args, **kwargs)
     return wrapper
```
tox.ini

```diff
@@ -5,7 +5,7 @@ ignore = E265,E741,E742,E743,W504,W605
 exclude = .git,
           __init__.py,
           setup.py,
-          tensorpack/train/eager.py,
+          tensorpack/compat/*,
           docs,
           examples,
           docs/conf.py
```