Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
505e28eb
Commit
505e28eb
authored
Mar 17, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Backport TensorSpec; tf=tf.compat.v1 in many files.
parent
e4941595
Changes
39
Show whitespace changes
Inline
Side-by-side
Showing
39 changed files
with
257 additions
and
104 deletions
+257
-104
docs/tutorial/faq.md
docs/tutorial/faq.md
+9
-3
examples/basics/mnist-convnet.py
examples/basics/mnist-convnet.py
+3
-3
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+1
-1
tensorpack/callbacks/graph.py
tensorpack/callbacks/graph.py
+1
-1
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+1
-1
tensorpack/callbacks/hooks.py
tensorpack/callbacks/hooks.py
+1
-1
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+2
-3
tensorpack/callbacks/monitor.py
tensorpack/callbacks/monitor.py
+1
-1
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+3
-3
tensorpack/callbacks/steps.py
tensorpack/callbacks/steps.py
+1
-1
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+1
-1
tensorpack/compat/__init__.py
tensorpack/compat/__init__.py
+39
-0
tensorpack/compat/tensor_spec.py
tensorpack/compat/tensor_spec.py
+106
-0
tensorpack/graph_builder/model_desc.py
tensorpack/graph_builder/model_desc.py
+5
-6
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+6
-5
tensorpack/libinfo.py
tensorpack/libinfo.py
+1
-1
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+1
-1
tensorpack/models/conv2d.py
tensorpack/models/conv2d.py
+1
-1
tensorpack/models/fc.py
tensorpack/models/fc.py
+1
-1
tensorpack/models/layer_norm.py
tensorpack/models/layer_norm.py
+1
-1
tensorpack/models/pool.py
tensorpack/models/pool.py
+1
-1
tensorpack/models/registry.py
tensorpack/models/registry.py
+2
-1
tensorpack/models/regularize.py
tensorpack/models/regularize.py
+6
-5
tensorpack/tfutils/collection.py
tensorpack/tfutils/collection.py
+2
-1
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+11
-24
tensorpack/tfutils/dependency.py
tensorpack/tfutils/dependency.py
+1
-1
tensorpack/tfutils/model_utils.py
tensorpack/tfutils/model_utils.py
+1
-1
tensorpack/tfutils/optimizer.py
tensorpack/tfutils/optimizer.py
+2
-1
tensorpack/tfutils/scope_utils.py
tensorpack/tfutils/scope_utils.py
+1
-1
tensorpack/tfutils/sesscreate.py
tensorpack/tfutils/sesscreate.py
+14
-13
tensorpack/tfutils/summary.py
tensorpack/tfutils/summary.py
+1
-1
tensorpack/tfutils/symbolic_functions.py
tensorpack/tfutils/symbolic_functions.py
+2
-1
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+2
-1
tensorpack/tfutils/varreplace.py
tensorpack/tfutils/varreplace.py
+1
-1
tensorpack/train/base.py
tensorpack/train/base.py
+3
-2
tensorpack/train/interface.py
tensorpack/train/interface.py
+6
-3
tensorpack/train/tower.py
tensorpack/train/tower.py
+13
-8
tensorpack/utils/argtools.py
tensorpack/utils/argtools.py
+2
-2
tox.ini
tox.ini
+1
-1
No files found.
docs/tutorial/faq.md
View file @
505e28eb
...
@@ -16,9 +16,15 @@ If you think:
...
@@ -16,9 +16,15 @@ If you think:
Then it is a good time to open an issue.
Then it is a good time to open an issue.
## How to print/dump intermediate results in training
## How to print/dump intermediate results during training
1.
Learn
`tf.Print`
.
1.
Learn
`tf.Print`
. Most of the times, adding one line in between:
```
python
tensor
=
obtain_a_tensor
()
tensor
=
tf
.
Print
(
tensor
,
[
tf
.
shape
(
tensor
),
tensor
],
tensor
.
name
,
summarize
=
100
)
use_the_tensor
(
tensor
)
```
is sufficient.
2.
Know
[
DumpTensors
](
../modules/callbacks.html#tensorpack.callbacks.DumpTensors
)
,
2.
Know
[
DumpTensors
](
../modules/callbacks.html#tensorpack.callbacks.DumpTensors
)
,
[
ProcessTensors
](
../modules/callbacks.html#tensorpack.callbacks.ProcessTensors
)
callbacks.
[
ProcessTensors
](
../modules/callbacks.html#tensorpack.callbacks.ProcessTensors
)
callbacks.
...
...
examples/basics/mnist-convnet.py
View file @
505e28eb
...
@@ -21,8 +21,8 @@ class Model(ModelDesc):
...
@@ -21,8 +21,8 @@ class Model(ModelDesc):
"""
"""
Define all the inputs (with type, shape, name) that the graph will need.
Define all the inputs (with type, shape, name) that the graph will need.
"""
"""
return
[
tf
.
placeholder
(
tf
.
float32
,
(
None
,
IMAGE_SIZE
,
IMAGE_SIZE
)
,
'input'
),
return
[
tf
.
TensorSpec
((
None
,
IMAGE_SIZE
,
IMAGE_SIZE
),
tf
.
float32
,
'input'
),
tf
.
placeholder
(
tf
.
int32
,
(
None
,)
,
'label'
)]
tf
.
TensorSpec
((
None
,),
tf
.
int32
,
'label'
)]
def
build_graph
(
self
,
image
,
label
):
def
build_graph
(
self
,
image
,
label
):
"""This function should build the model which takes the input variables
"""This function should build the model which takes the input variables
...
@@ -51,7 +51,7 @@ class Model(ModelDesc):
...
@@ -51,7 +51,7 @@ class Model(ModelDesc):
cost
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
logits
=
logits
,
labels
=
label
)
cost
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
logits
=
logits
,
labels
=
label
)
cost
=
tf
.
reduce_mean
(
cost
,
name
=
'cross_entropy_loss'
)
# the average cross-entropy loss
cost
=
tf
.
reduce_mean
(
cost
,
name
=
'cross_entropy_loss'
)
# the average cross-entropy loss
correct
=
tf
.
cast
(
tf
.
nn
.
in_top_k
(
logits
,
label
,
1
),
tf
.
float32
,
name
=
'correct'
)
correct
=
tf
.
cast
(
tf
.
nn
.
in_top_k
(
predictions
=
logits
,
targets
=
label
,
k
=
1
),
tf
.
float32
,
name
=
'correct'
)
accuracy
=
tf
.
reduce_mean
(
correct
,
name
=
'accuracy'
)
accuracy
=
tf
.
reduce_mean
(
correct
,
name
=
'accuracy'
)
# This will monitor training error & accuracy (in a moving average fashion). The value will be automatically
# This will monitor training error & accuracy (in a moving average fashion). The value will be automatically
...
...
tensorpack/callbacks/base.py
View file @
505e28eb
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
from
abc
import
ABCMeta
from
abc
import
ABCMeta
import
six
import
six
import
tensorflow
as
tf
from
..compat
import
tfv1
as
tf
from
..tfutils.common
import
get_op_or_tensor_by_name
from
..tfutils.common
import
get_op_or_tensor_by_name
...
...
tensorpack/callbacks/graph.py
View file @
505e28eb
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
import
numpy
as
np
import
numpy
as
np
import
os
import
os
import
tensorflow
as
tf
from
six.moves
import
zip
from
six.moves
import
zip
from
..compat
import
tfv1
as
tf
from
..tfutils.common
import
get_op_tensor_name
from
..tfutils.common
import
get_op_tensor_name
from
..utils
import
logger
from
..utils
import
logger
from
.base
import
Callback
from
.base
import
Callback
...
...
tensorpack/callbacks/group.py
View file @
505e28eb
...
@@ -6,7 +6,7 @@ import traceback
...
@@ -6,7 +6,7 @@ import traceback
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
time
import
time
as
timer
from
time
import
time
as
timer
import
six
import
six
import
tensorflow
as
tf
from
..compat
import
tfv1
as
tf
from
..utils
import
logger
from
..utils
import
logger
from
..utils.utils
import
humanize_time_delta
from
..utils.utils
import
humanize_time_delta
...
...
tensorpack/callbacks/hooks.py
View file @
505e28eb
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..
tfutils.common
import
tfv1
from
..
compat
import
tfv1
from
..utils.develop
import
HIDE_DOC
from
..utils.develop
import
HIDE_DOC
from
.base
import
Callback
from
.base
import
Callback
...
...
tensorpack/callbacks/inference_runner.py
View file @
505e28eb
...
@@ -5,15 +5,14 @@
...
@@ -5,15 +5,14 @@
import
itertools
import
itertools
import
sys
import
sys
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
import
tensorflow
as
tf
import
tqdm
import
tqdm
from
six.moves
import
range
from
six.moves
import
range
from
tensorflow.python.training.monitored_session
import
_HookedSession
as
HookedSession
from
tensorflow.python.training.monitored_session
import
_HookedSession
as
HookedSession
from
..compat
import
tfv1
as
tf
from
..dataflow.base
import
DataFlow
from
..dataflow.base
import
DataFlow
from
..input_source
import
FeedInput
,
InputSource
,
QueueInput
,
StagingInput
from
..input_source
import
FeedInput
,
InputSource
,
QueueInput
,
StagingInput
from
..tfutils.tower
import
PredictTowerContext
from
..tfutils.tower
import
PredictTowerContext
from
..tfutils.common
import
tfv1
from
..utils
import
logger
from
..utils
import
logger
from
..utils.utils
import
get_tqdm_kwargs
from
..utils.utils
import
get_tqdm_kwargs
from
.base
import
Callback
from
.base
import
Callback
...
@@ -28,7 +27,7 @@ def _device_from_int(dev):
...
@@ -28,7 +27,7 @@ def _device_from_int(dev):
return
'/gpu:{}'
.
format
(
dev
)
if
dev
>=
0
else
'/cpu:0'
return
'/gpu:{}'
.
format
(
dev
)
if
dev
>=
0
else
'/cpu:0'
class
InferencerToHook
(
tf
v1
.
train
.
SessionRunHook
):
class
InferencerToHook
(
tf
.
train
.
SessionRunHook
):
def
__init__
(
self
,
inf
,
fetches
):
def
__init__
(
self
,
inf
,
fetches
):
self
.
_inf
=
inf
self
.
_inf
=
inf
self
.
_fetches
=
fetches
self
.
_fetches
=
fetches
...
...
tensorpack/callbacks/monitor.py
View file @
505e28eb
...
@@ -12,8 +12,8 @@ import time
...
@@ -12,8 +12,8 @@ import time
from
collections
import
defaultdict
from
collections
import
defaultdict
from
datetime
import
datetime
from
datetime
import
datetime
import
six
import
six
import
tensorflow
as
tf
from
..compat
import
tfv1
as
tf
from
..libinfo
import
__git_version__
from
..libinfo
import
__git_version__
from
..tfutils.summary
import
create_image_summary
,
create_scalar_summary
from
..tfutils.summary
import
create_image_summary
,
create_scalar_summary
from
..utils
import
logger
from
..utils
import
logger
...
...
tensorpack/callbacks/saver.py
View file @
505e28eb
...
@@ -4,8 +4,8 @@
...
@@ -4,8 +4,8 @@
import
os
import
os
from
datetime
import
datetime
from
datetime
import
datetime
import
tensorflow
as
tf
from
..compat
import
tfv1
as
tf
from
..utils
import
logger
from
..utils
import
logger
from
.base
import
Callback
from
.base
import
Callback
...
@@ -40,8 +40,8 @@ class ModelSaver(Callback):
...
@@ -40,8 +40,8 @@ class ModelSaver(Callback):
if
checkpoint_dir
is
None
:
if
checkpoint_dir
is
None
:
checkpoint_dir
=
logger
.
get_logger_dir
()
checkpoint_dir
=
logger
.
get_logger_dir
()
if
checkpoint_dir
is
not
None
:
if
checkpoint_dir
is
not
None
:
if
not
tf
.
gfile
.
IsDirectory
(
checkpoint_dir
):
if
not
tf
.
gfile
.
IsDirectory
(
checkpoint_dir
):
# v2: tf.io.gfile.isdir
tf
.
gfile
.
MakeDirs
(
checkpoint_dir
)
tf
.
gfile
.
MakeDirs
(
checkpoint_dir
)
# v2: tf.io.gfile.makedirs
self
.
checkpoint_dir
=
checkpoint_dir
self
.
checkpoint_dir
=
checkpoint_dir
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
...
...
tensorpack/callbacks/steps.py
View file @
505e28eb
...
@@ -3,10 +3,10 @@
...
@@ -3,10 +3,10 @@
""" Some common step callbacks. """
""" Some common step callbacks. """
import
tensorflow
as
tf
import
tqdm
import
tqdm
from
six.moves
import
zip
from
six.moves
import
zip
from
..compat
import
tfv1
as
tf
from
..tfutils.common
import
get_global_step_var
,
get_op_tensor_name
from
..tfutils.common
import
get_global_step_var
,
get_op_tensor_name
from
..utils
import
logger
from
..utils
import
logger
from
..utils.naming
import
GLOBAL_STEP_INCR_OP_NAME
from
..utils.naming
import
GLOBAL_STEP_INCR_OP_NAME
...
...
tensorpack/callbacks/summary.py
View file @
505e28eb
...
@@ -4,8 +4,8 @@
...
@@ -4,8 +4,8 @@
import
numpy
as
np
import
numpy
as
np
from
collections
import
deque
from
collections
import
deque
import
tensorflow
as
tf
from
..compat
import
tfv1
as
tf
from
..tfutils.common
import
get_op_tensor_name
from
..tfutils.common
import
get_op_tensor_name
from
..utils
import
logger
from
..utils
import
logger
from
..utils.naming
import
MOVING_SUMMARY_OPS_KEY
from
..utils.naming
import
MOVING_SUMMARY_OPS_KEY
...
...
tensorpack/compat/__init__.py
0 → 100644
View file @
505e28eb
#!/usr/bin/env python
import
tensorflow
as
tf
def
backport_tensor_spec
():
if
hasattr
(
tf
,
'TensorSpec'
):
return
tf
.
TensorSpec
try
:
# available since 1.7
from
tensorflow.python.framework.tensor_spec
import
TensorSpec
except
ImportError
:
pass
else
:
tf
.
TensorSpec
=
TensorSpec
return
TensorSpec
from
.tensor_spec
import
TensorSpec
tf
.
TensorSpec
=
TensorSpec
return
TensorSpec
def
is_tfv2
():
try
:
from
tensorflow.python
import
tf2
return
tf2
.
enabled
()
except
Exception
:
return
False
if
is_tfv2
():
tfv1
=
tf
.
compat
.
v1
if
not
hasattr
(
tf
,
'layers'
):
# promised at https://github.com/tensorflow/community/pull/24#issuecomment-440453886
tf
.
layers
=
tf
.
keras
.
layers
else
:
tfv1
=
tf
tensorpack/compat/tensor_spec.py
0 → 100644
View file @
505e28eb
"""
Copied from tensorflow/python/framework/tensor_spec.py
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
from
tensorflow.python.framework
import
common_shapes
from
tensorflow.python.framework
import
dtypes
from
tensorflow.python.framework
import
ops
from
tensorflow.python.framework
import
tensor_shape
class
TensorSpec
(
object
):
"""Describes a tf.Tensor.
Metadata for describing the `tf.Tensor` objects accepted or returned
by some TensorFlow APIs.
"""
__slots__
=
[
"_shape"
,
"_shape_tuple"
,
"_dtype"
,
"_name"
]
def
__init__
(
self
,
shape
,
dtype
=
dtypes
.
float32
,
name
=
None
):
"""Creates a TensorSpec.
Args:
shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
dtype: Value convertible to `tf.DType`. The type of the tensor values.
name: Optional name for the Tensor.
Raises:
TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is
not convertible to a `tf.DType`.
"""
self
.
_shape
=
tensor_shape
.
TensorShape
(
shape
)
try
:
self
.
_shape_tuple
=
tuple
(
self
.
shape
.
as_list
())
except
ValueError
:
self
.
_shape_tuple
=
None
self
.
_dtype
=
dtypes
.
as_dtype
(
dtype
)
self
.
_name
=
name
@
classmethod
def
from_spec
(
cls
,
spec
,
name
=
None
):
return
cls
(
spec
.
shape
,
spec
.
dtype
,
name
or
spec
.
name
)
@
classmethod
def
from_tensor
(
cls
,
tensor
,
name
=
None
):
if
isinstance
(
tensor
,
ops
.
EagerTensor
):
return
TensorSpec
(
tensor
.
shape
,
tensor
.
dtype
,
name
)
elif
isinstance
(
tensor
,
ops
.
Tensor
):
return
TensorSpec
(
tensor
.
shape
,
tensor
.
dtype
,
name
or
tensor
.
op
.
name
)
else
:
raise
ValueError
(
"`tensor` should be a tf.Tensor"
)
@
property
def
shape
(
self
):
"""Returns the `TensorShape` that represents the shape of the tensor."""
return
self
.
_shape
@
property
def
dtype
(
self
):
"""Returns the `dtype` of elements in the tensor."""
return
self
.
_dtype
@
property
def
name
(
self
):
"""Returns the (optionally provided) name of the described tensor."""
return
self
.
_name
def
is_compatible_with
(
self
,
spec_or_tensor
):
"""Returns True if spec_or_tensor is compatible with this TensorSpec.
Two tensors are considered compatible if they have the same dtype
and their shapes are compatible (see `tf.TensorShape.is_compatible_with`).
Args:
spec_or_tensor: A tf.TensorSpec or a tf.Tensor
Returns:
True if spec_or_tensor is compatible with self.
"""
return
(
self
.
_dtype
.
is_compatible_with
(
spec_or_tensor
.
dtype
)
and
self
.
_shape
.
is_compatible_with
(
spec_or_tensor
.
shape
))
def
__repr__
(
self
):
return
"TensorSpec(shape={}, dtype={}, name={})"
.
format
(
self
.
shape
,
repr
(
self
.
dtype
),
repr
(
self
.
name
))
def
__hash__
(
self
):
return
hash
((
self
.
_shape_tuple
,
self
.
dtype
))
def
__eq__
(
self
,
other
):
return
(
self
.
_shape_tuple
==
other
.
_shape_tuple
# pylint: disable=protected-access
and
self
.
dtype
==
other
.
dtype
and
self
.
_name
==
other
.
_name
)
# pylint: disable=protected-access
def
__ne__
(
self
,
other
):
return
not
self
==
other
def
__reduce__
(
self
):
return
TensorSpec
,
(
self
.
_shape
,
self
.
_dtype
,
self
.
_name
)
tensorpack/graph_builder/model_desc.py
View file @
505e28eb
...
@@ -7,13 +7,12 @@ import tensorflow as tf
...
@@ -7,13 +7,12 @@ import tensorflow as tf
from
..models.regularize
import
regularize_cost_from_collection
from
..models.regularize
import
regularize_cost_from_collection
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.common
import
get_tf_version_tuple
from
..utils
import
logger
from
..utils
import
logger
from
..utils.argtools
import
memoized_method
from
..utils.argtools
import
memoized_method
from
..utils.develop
import
log_deprecated
from
..utils.develop
import
log_deprecated
from
..compat
import
backport_tensor_spec
,
tfv1
if
get_tf_version_tuple
()
>=
(
1
,
7
):
TensorSpec
=
backport_tensor_spec
()
from
tensorflow.python.framework.tensor_spec
import
TensorSpec
__all__
=
[
'InputDesc'
,
'ModelDesc'
,
'ModelDescBase'
]
__all__
=
[
'InputDesc'
,
'ModelDesc'
,
'ModelDescBase'
]
...
@@ -49,8 +48,8 @@ class InputDesc(
...
@@ -49,8 +48,8 @@ class InputDesc(
Returns:
Returns:
tf.Tensor:
tf.Tensor:
"""
"""
with
tf
.
name_scope
(
None
):
# clear any name scope it might get called in
with
tf
v1
.
name_scope
(
None
):
# clear any name scope it might get called in
ret
=
tf
.
placeholder
(
ret
=
tf
v1
.
placeholder
(
self
.
type
,
shape
=
self
.
shape
,
name
=
self
.
name
)
self
.
type
,
shape
=
self
.
shape
,
name
=
self
.
name
)
self
.
_register_cached_placeholder
(
ret
)
self
.
_register_cached_placeholder
(
ret
)
return
ret
return
ret
...
@@ -63,7 +62,7 @@ class InputDesc(
...
@@ -63,7 +62,7 @@ class InputDesc(
Returns:
Returns:
tf.Tensor:
tf.Tensor:
"""
"""
g
=
tf
.
get_default_graph
()
g
=
tf
v1
.
get_default_graph
()
if
g
in
self
.
_cached_placeholder
:
if
g
in
self
.
_cached_placeholder
:
return
self
.
_cached_placeholder
[
g
]
return
self
.
_cached_placeholder
[
g
]
else
:
else
:
...
...
tensorpack/input_source/input_source.py
View file @
505e28eb
...
@@ -8,6 +8,7 @@ from itertools import chain
...
@@ -8,6 +8,7 @@ from itertools import chain
import
tensorflow
as
tf
import
tensorflow
as
tf
from
six.moves
import
range
,
zip
from
six.moves
import
range
,
zip
from
..compat
import
tfv1
from
..callbacks.base
import
Callback
,
CallbackFactory
from
..callbacks.base
import
Callback
,
CallbackFactory
from
..callbacks.graph
import
RunOp
from
..callbacks.graph
import
RunOp
from
..dataflow
import
DataFlow
,
MapData
,
RepeatedData
from
..dataflow
import
DataFlow
,
MapData
,
RepeatedData
...
@@ -84,7 +85,7 @@ class FeedInput(InputSource):
...
@@ -84,7 +85,7 @@ class FeedInput(InputSource):
dp
=
next
(
self
.
_itr
)
dp
=
next
(
self
.
_itr
)
assert
len
(
dp
)
==
len
(
self
.
_placeholders
),
"[FeedInput] datapoints and inputs are of different length!"
assert
len
(
dp
)
==
len
(
self
.
_placeholders
),
"[FeedInput] datapoints and inputs are of different length!"
feed
=
_make_feeds
(
self
.
_placeholders
,
dp
)
feed
=
_make_feeds
(
self
.
_placeholders
,
dp
)
return
tf
.
train
.
SessionRunArgs
(
fetches
=
[],
feed_dict
=
feed
)
return
tf
v1
.
train
.
SessionRunArgs
(
fetches
=
[],
feed_dict
=
feed
)
def
_reset
(
self
):
def
_reset
(
self
):
self
.
_itr
=
self
.
_ds
.
__iter__
()
self
.
_itr
=
self
.
_ds
.
__iter__
()
...
@@ -228,9 +229,9 @@ class QueueInput(FeedfreeInput):
...
@@ -228,9 +229,9 @@ class QueueInput(FeedfreeInput):
"""
"""
self
.
thread
.
pause
()
# pause enqueue
self
.
thread
.
pause
()
# pause enqueue
opt
=
tf
.
RunOptions
()
opt
=
tf
v1
.
RunOptions
()
opt
.
timeout_in_ms
=
2000
# 2s
opt
.
timeout_in_ms
=
2000
# 2s
sess
=
tf
.
get_default_session
()
sess
=
tf
v1
.
get_default_session
()
# dequeue until empty
# dequeue until empty
try
:
try
:
while
True
:
while
True
:
...
@@ -304,7 +305,7 @@ class BatchQueueInput(QueueInput):
...
@@ -304,7 +305,7 @@ class BatchQueueInput(QueueInput):
# prepare placeholders without the first dimension
# prepare placeholders without the first dimension
placehdrs_nobatch
=
[]
placehdrs_nobatch
=
[]
for
p
in
self
.
input_placehdrs
:
for
p
in
self
.
input_placehdrs
:
placehdrs_nobatch
.
append
(
tf
.
placeholder
(
placehdrs_nobatch
.
append
(
tf
v1
.
placeholder
(
dtype
=
p
.
dtype
,
shape
=
p
.
get_shape
()
.
as_list
()[
1
:],
dtype
=
p
.
dtype
,
shape
=
p
.
get_shape
()
.
as_list
()[
1
:],
name
=
get_op_tensor_name
(
p
.
name
)[
0
]
+
'-nobatch'
))
name
=
get_op_tensor_name
(
p
.
name
)[
0
]
+
'-nobatch'
))
...
@@ -546,7 +547,7 @@ class StagingInput(FeedfreeInput):
...
@@ -546,7 +547,7 @@ class StagingInput(FeedfreeInput):
unstage_ops
=
self
.
_input
.
_get_unstage_ops
()
unstage_ops
=
self
.
_input
.
_get_unstage_ops
()
unstage_op
=
tf
.
group
(
*
unstage_ops
,
name
=
'unstage_all'
)
unstage_op
=
tf
.
group
(
*
unstage_ops
,
name
=
'unstage_all'
)
self
.
_check_dependency_op
=
unstage_ops
[
0
]
self
.
_check_dependency_op
=
unstage_ops
[
0
]
self
.
fetches
=
tf
.
train
.
SessionRunArgs
(
self
.
fetches
=
tf
v1
.
train
.
SessionRunArgs
(
fetches
=
[
self
.
stage_op
,
unstage_op
])
fetches
=
[
self
.
stage_op
,
unstage_op
])
def
_prefill
(
self
,
sess
):
def
_prefill
(
self
,
sess
):
...
...
tensorpack/libinfo.py
View file @
505e28eb
...
@@ -52,7 +52,7 @@ os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '0'
...
@@ -52,7 +52,7 @@ os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '0'
try
:
try
:
import
tensorflow
as
tf
# noqa
import
tensorflow
as
tf
# noqa
_version
=
tf
.
__version__
.
split
(
'.'
)
_version
=
tf
.
__version__
.
split
(
'.'
)
assert
int
(
_version
[
0
])
>=
1
and
int
(
_version
[
1
])
>=
3
,
"TF>=1.3 is required!"
assert
(
int
(
_version
[
0
]),
int
(
_version
[
1
]))
>=
(
1
,
3
)
,
"TF>=1.3 is required!"
_HAS_TF
=
True
_HAS_TF
=
True
except
ImportError
:
except
ImportError
:
print
(
"Failed to import tensorflow."
)
print
(
"Failed to import tensorflow."
)
...
...
tensorpack/models/batch_norm.py
View file @
505e28eb
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
import
re
import
re
import
six
import
six
import
tensorflow
as
tf
from
..compat
import
tfv1
as
tf
# this should be avoided first in model code
from
tensorflow.python.training
import
moving_averages
from
tensorflow.python.training
import
moving_averages
from
..tfutils.collection
import
backup_collection
,
restore_collection
from
..tfutils.collection
import
backup_collection
,
restore_collection
...
...
tensorpack/models/conv2d.py
View file @
505e28eb
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# File: conv2d.py
# File: conv2d.py
import
tensorflow
as
tf
from
..compat
import
tfv1
as
tf
# this should be avoided first in model code
from
..tfutils.common
import
get_tf_version_tuple
from
..tfutils.common
import
get_tf_version_tuple
from
..utils.argtools
import
get_data_format
,
shape2d
,
shape4d
,
log_once
from
..utils.argtools
import
get_data_format
,
shape2d
,
shape4d
,
log_once
...
...
tensorpack/models/fc.py
View file @
505e28eb
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
from
..compat
import
tfv1
as
tf
# this should be avoided first in model code
from
..tfutils.common
import
get_tf_version_tuple
from
..tfutils.common
import
get_tf_version_tuple
from
.common
import
VariableHolder
,
layer_register
from
.common
import
VariableHolder
,
layer_register
...
...
tensorpack/models/layer_norm.py
View file @
505e28eb
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# File: layer_norm.py
# File: layer_norm.py
import
tensorflow
as
tf
from
..compat
import
tfv1
as
tf
# this should be avoided first in model code
from
..utils.argtools
import
get_data_format
from
..utils.argtools
import
get_data_format
from
.common
import
VariableHolder
,
layer_register
from
.common
import
VariableHolder
,
layer_register
...
...
tensorpack/models/pool.py
View file @
505e28eb
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# File: pool.py
# File: pool.py
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
from
..compat
import
tfv1
as
tf
# this should be avoided first in model code
from
..utils.argtools
import
get_data_format
,
shape2d
from
..utils.argtools
import
get_data_format
,
shape2d
from
..utils.develop
import
log_deprecated
from
..utils.develop
import
log_deprecated
...
...
tensorpack/models/registry.py
View file @
505e28eb
...
@@ -8,6 +8,7 @@ from functools import wraps
...
@@ -8,6 +8,7 @@ from functools import wraps
import
six
import
six
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..compat
import
tfv1
from
..tfutils.argscope
import
get_arg_scope
from
..tfutils.argscope
import
get_arg_scope
from
..tfutils.model_utils
import
get_shape_str
from
..tfutils.model_utils
import
get_shape_str
from
..utils
import
logger
from
..utils
import
logger
...
@@ -117,7 +118,7 @@ def layer_register(
...
@@ -117,7 +118,7 @@ def layer_register(
# del actual_args[k]
# del actual_args[k]
if
name
is
not
None
:
# use scope
if
name
is
not
None
:
# use scope
with
tf
.
variable_scope
(
name
)
as
scope
:
with
tf
v1
.
variable_scope
(
name
)
as
scope
:
# this name is only used to surpress logging, doesn't hurt to do some heuristics
# this name is only used to surpress logging, doesn't hurt to do some heuristics
scope_name
=
re
.
sub
(
'tower[0-9]+/'
,
''
,
scope
.
name
)
scope_name
=
re
.
sub
(
'tower[0-9]+/'
,
''
,
scope
.
name
)
do_log_shape
=
log_shape
and
scope_name
not
in
_LAYER_LOGGED
do_log_shape
=
log_shape
and
scope_name
not
in
_LAYER_LOGGED
...
...
tensorpack/models/regularize.py
View file @
505e28eb
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
import
re
import
re
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..compat
import
tfv1
from
..tfutils.common
import
get_tf_version_tuple
from
..tfutils.common
import
get_tf_version_tuple
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.tower
import
get_current_tower_context
from
..utils
import
logger
from
..utils
import
logger
...
@@ -60,13 +61,13 @@ def regularize_cost(regex, func, name='regularize_cost'):
...
@@ -60,13 +61,13 @@ def regularize_cost(regex, func, name='regularize_cost'):
# If vars are shared, regularize all of them
# If vars are shared, regularize all of them
# If vars are replicated, only regularize those in the current tower
# If vars are replicated, only regularize those in the current tower
if
ctx
.
has_own_variables
:
if
ctx
.
has_own_variables
:
params
=
ctx
.
get_collection_in_tower
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
params
=
ctx
.
get_collection_in_tower
(
tf
v1
.
GraphKeys
.
TRAINABLE_VARIABLES
)
else
:
else
:
params
=
tf
.
trainable_variables
()
params
=
tf
v1
.
trainable_variables
()
names
=
[]
names
=
[]
with
tf
.
name_scope
(
name
+
'_internals'
):
with
tf
v1
.
name_scope
(
name
+
'_internals'
):
costs
=
[]
costs
=
[]
for
p
in
params
:
for
p
in
params
:
para_name
=
p
.
op
.
name
para_name
=
p
.
op
.
name
...
@@ -119,9 +120,9 @@ def regularize_cost_from_collection(name='regularize_cost'):
...
@@ -119,9 +120,9 @@ def regularize_cost_from_collection(name='regularize_cost'):
# NOTE: this collection doesn't always grow with towers.
# NOTE: this collection doesn't always grow with towers.
# It only grows with actual variable creation, but not get_variable call.
# It only grows with actual variable creation, but not get_variable call.
if
ctx
.
has_own_variables
:
# be careful of the first tower (name='')
if
ctx
.
has_own_variables
:
# be careful of the first tower (name='')
losses
=
ctx
.
get_collection_in_tower
(
tf
.
GraphKeys
.
REGULARIZATION_LOSSES
)
losses
=
ctx
.
get_collection_in_tower
(
tf
v1
.
GraphKeys
.
REGULARIZATION_LOSSES
)
else
:
else
:
losses
=
tf
.
get_collection
(
tf
.
GraphKeys
.
REGULARIZATION_LOSSES
)
losses
=
tf
v1
.
get_collection
(
tfv1
.
GraphKeys
.
REGULARIZATION_LOSSES
)
if
len
(
losses
)
>
0
:
if
len
(
losses
)
>
0
:
logger
.
info
(
"regularize_cost_from_collection() found {} regularizers "
logger
.
info
(
"regularize_cost_from_collection() found {} regularizers "
"in REGULARIZATION_LOSSES collection."
.
format
(
len
(
losses
)))
"in REGULARIZATION_LOSSES collection."
.
format
(
len
(
losses
)))
...
...
tensorpack/tfutils/collection.py
View file @
505e28eb
...
@@ -5,7 +5,8 @@
...
@@ -5,7 +5,8 @@
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
copy
import
copy
from
copy
import
copy
import
six
import
six
import
tensorflow
as
tf
from
..compat
import
tfv1
as
tf
from
..utils
import
logger
from
..utils
import
logger
from
..utils.argtools
import
memoized
from
..utils.argtools
import
memoized
...
...
tensorpack/tfutils/common.py
View file @
505e28eb
...
@@ -5,12 +5,13 @@
...
@@ -5,12 +5,13 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
six.moves
import
map
from
six.moves
import
map
from
..compat
import
tfv1
from
..utils.argtools
import
graph_memoized
from
..utils.argtools
import
graph_memoized
__all__
=
[
'get_default_sess_config'
,
__all__
=
[
'get_default_sess_config'
,
'get_global_step_value'
,
'get_global_step_value'
,
'get_global_step_var'
,
'get_global_step_var'
,
'get_tf_version_tuple'
'get_tf_version_tuple'
,
# 'get_op_tensor_name',
# 'get_op_tensor_name',
# 'get_tensors_by_names',
# 'get_tensors_by_names',
# 'get_op_or_tensor_by_name',
# 'get_op_or_tensor_by_name',
...
@@ -30,7 +31,7 @@ def get_default_sess_config(mem_fraction=0.99):
...
@@ -30,7 +31,7 @@ def get_default_sess_config(mem_fraction=0.99):
Returns:
Returns:
tf.ConfigProto: the config to use.
tf.ConfigProto: the config to use.
"""
"""
conf
=
tf
.
ConfigProto
()
conf
=
tf
v1
.
ConfigProto
()
conf
.
allow_soft_placement
=
True
conf
.
allow_soft_placement
=
True
# conf.log_device_placement = True
# conf.log_device_placement = True
...
@@ -64,9 +65,9 @@ def get_global_step_var():
...
@@ -64,9 +65,9 @@ def get_global_step_var():
Returns:
Returns:
tf.Tensor: the global_step variable in the current graph. Create if doesn't exist.
tf.Tensor: the global_step variable in the current graph. Create if doesn't exist.
"""
"""
scope
=
tf
.
VariableScope
(
reuse
=
False
,
name
=
''
)
# the root vs
scope
=
tf
v1
.
VariableScope
(
reuse
=
False
,
name
=
''
)
# the root vs
with
tf
.
variable_scope
(
scope
):
with
tf
v1
.
variable_scope
(
scope
):
var
=
tf
.
train
.
get_or_create_global_step
()
var
=
tf
v1
.
train
.
get_or_create_global_step
()
return
var
return
var
...
@@ -78,8 +79,8 @@ def get_global_step_value():
...
@@ -78,8 +79,8 @@ def get_global_step_value():
Has to be called under a default session.
Has to be called under a default session.
"""
"""
return
tf
.
train
.
global_step
(
return
tf
v1
.
train
.
global_step
(
tf
.
get_default_session
(),
tf
v1
.
get_default_session
(),
get_global_step_var
())
get_global_step_var
())
...
@@ -108,7 +109,7 @@ def get_tensors_by_names(names):
...
@@ -108,7 +109,7 @@ def get_tensors_by_names(names):
names (list):
names (list):
"""
"""
ret
=
[]
ret
=
[]
G
=
tf
.
get_default_graph
()
G
=
tf
v1
.
get_default_graph
()
for
n
in
names
:
for
n
in
names
:
opn
,
varn
=
get_op_tensor_name
(
n
)
opn
,
varn
=
get_op_tensor_name
(
n
)
ret
.
append
(
G
.
get_tensor_by_name
(
varn
))
ret
.
append
(
G
.
get_tensor_by_name
(
varn
))
...
@@ -125,7 +126,7 @@ def get_op_or_tensor_by_name(name):
...
@@ -125,7 +126,7 @@ def get_op_or_tensor_by_name(name):
Raises:
Raises:
KeyError, if the name doesn't exist
KeyError, if the name doesn't exist
"""
"""
G
=
tf
.
get_default_graph
()
G
=
tf
v1
.
get_default_graph
()
def
f
(
n
):
def
f
(
n
):
if
len
(
n
)
>=
3
and
n
[
-
2
]
==
':'
:
if
len
(
n
)
>=
3
and
n
[
-
2
]
==
':'
:
...
@@ -140,7 +141,7 @@ def get_op_or_tensor_by_name(name):
...
@@ -140,7 +141,7 @@ def get_op_or_tensor_by_name(name):
def
gpu_available_in_session
():
def
gpu_available_in_session
():
sess
=
tf
.
get_default_session
()
sess
=
tf
v1
.
get_default_session
()
for
dev
in
sess
.
list_devices
():
for
dev
in
sess
.
list_devices
():
if
dev
.
device_type
.
lower
()
==
'gpu'
:
if
dev
.
device_type
.
lower
()
==
'gpu'
:
return
True
return
True
...
@@ -152,17 +153,3 @@ def get_tf_version_tuple():
...
@@ -152,17 +153,3 @@ def get_tf_version_tuple():
Return TensorFlow version as a 2-element tuple (for comparison).
Return TensorFlow version as a 2-element tuple (for comparison).
"""
"""
return
tuple
(
map
(
int
,
tf
.
__version__
.
split
(
'.'
)[:
2
]))
return
tuple
(
map
(
int
,
tf
.
__version__
.
split
(
'.'
)[:
2
]))
def
is_tf2
():
try
:
from
tensorflow.python
import
tf2
return
tf2
.
enabled
()
except
Exception
:
return
False
if
is_tf2
():
tfv1
=
tf
.
compat
.
v1
else
:
tfv1
=
tf
tensorpack/tfutils/dependency.py
View file @
505e28eb
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.contrib.graph_editor
import
get_backward_walk_ops
from
..utils.argtools
import
graph_memoized
from
..utils.argtools
import
graph_memoized
...
@@ -33,6 +32,7 @@ def dependency_of_targets(targets, op):
...
@@ -33,6 +32,7 @@ def dependency_of_targets(targets, op):
op
=
op
.
op
op
=
op
.
op
assert
isinstance
(
op
,
tf
.
Operation
),
op
assert
isinstance
(
op
,
tf
.
Operation
),
op
from
tensorflow.contrib.graph_editor
import
get_backward_walk_ops
# alternative implementation can use graph_util.extract_sub_graph
# alternative implementation can use graph_util.extract_sub_graph
dependent_ops
=
get_backward_walk_ops
(
targets
,
control_inputs
=
True
)
dependent_ops
=
get_backward_walk_ops
(
targets
,
control_inputs
=
True
)
return
op
in
dependent_ops
return
op
in
dependent_ops
...
...
tensorpack/tfutils/model_utils.py
View file @
505e28eb
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# File: model_utils.py
# File: model_utils.py
# Author: tensorpack contributors
# Author: tensorpack contributors
import
tensorflow
as
tf
from
..compat
import
tfv1
as
tf
from
tabulate
import
tabulate
from
tabulate
import
tabulate
from
termcolor
import
colored
from
termcolor
import
colored
...
...
tensorpack/tfutils/optimizer.py
View file @
505e28eb
...
@@ -5,7 +5,8 @@
...
@@ -5,7 +5,8 @@
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..tfutils.common
import
get_tf_version_tuple
,
tfv1
from
..tfutils.common
import
get_tf_version_tuple
from
..compat
import
tfv1
from
..utils.develop
import
HIDE_DOC
from
..utils.develop
import
HIDE_DOC
from
.gradproc
import
FilterNoneGrad
,
GradientProcessor
from
.gradproc
import
FilterNoneGrad
,
GradientProcessor
...
...
tensorpack/tfutils/scope_utils.py
View file @
505e28eb
...
@@ -4,8 +4,8 @@
...
@@ -4,8 +4,8 @@
import
functools
import
functools
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
import
tensorflow
as
tf
from
..compat
import
tfv1
as
tf
from
..utils.argtools
import
graph_memoized
from
..utils.argtools
import
graph_memoized
from
.common
import
get_tf_version_tuple
from
.common
import
get_tf_version_tuple
...
...
tensorpack/tfutils/sesscreate.py
View file @
505e28eb
...
@@ -2,10 +2,7 @@
...
@@ -2,10 +2,7 @@
# File: sesscreate.py
# File: sesscreate.py
import
tensorflow
as
tf
from
..compat
import
tfv1
as
tf
,
is_tfv2
from
tensorflow.contrib.graph_editor
import
get_backward_walk_ops
from
..tfutils.common
import
tfv1
from
..utils
import
logger
from
..utils
import
logger
from
.common
import
get_default_sess_config
from
.common
import
get_default_sess_config
...
@@ -20,7 +17,7 @@ A SessionCreator should:
...
@@ -20,7 +17,7 @@ A SessionCreator should:
"""
"""
class
NewSessionCreator
(
tf
v1
.
train
.
SessionCreator
):
class
NewSessionCreator
(
tf
.
train
.
SessionCreator
):
def
__init__
(
self
,
target
=
''
,
config
=
None
):
def
__init__
(
self
,
target
=
''
,
config
=
None
):
"""
"""
Args:
Args:
...
@@ -59,11 +56,15 @@ bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
...
@@ -59,11 +56,15 @@ bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
return
False
return
False
def
run
(
op
):
def
run
(
op
):
if
not
is_tfv2
():
from
tensorflow.contrib.graph_editor
import
get_backward_walk_ops
deps
=
get_backward_walk_ops
(
op
,
control_inputs
=
True
)
deps
=
get_backward_walk_ops
(
op
,
control_inputs
=
True
)
for
dep_op
in
deps
:
for
dep_op
in
deps
:
if
blocking_op
(
dep_op
):
if
blocking_op
(
dep_op
):
logger
.
warn
(
logger
.
warn
(
"Initializer '{}' depends on a blocking op '{}'. This initializer is likely to hang!"
.
format
(
"Initializer '{}' depends on a blocking op '{}'. "
"This initializer is likely to hang!"
.
format
(
op
.
name
,
dep_op
.
name
))
op
.
name
,
dep_op
.
name
))
sess
.
run
(
op
)
sess
.
run
(
op
)
...
@@ -73,7 +74,7 @@ bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
...
@@ -73,7 +74,7 @@ bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
return
sess
return
sess
class
ReuseSessionCreator
(
tf
v1
.
train
.
SessionCreator
):
class
ReuseSessionCreator
(
tf
.
train
.
SessionCreator
):
"""
"""
Returns an existing session.
Returns an existing session.
"""
"""
...
@@ -88,7 +89,7 @@ class ReuseSessionCreator(tfv1.train.SessionCreator):
...
@@ -88,7 +89,7 @@ class ReuseSessionCreator(tfv1.train.SessionCreator):
return
self
.
sess
return
self
.
sess
class
SessionCreatorAdapter
(
tf
v1
.
train
.
SessionCreator
):
class
SessionCreatorAdapter
(
tf
.
train
.
SessionCreator
):
"""
"""
Apply a function on the output of a SessionCreator. Can be used to create a debug session.
Apply a function on the output of a SessionCreator. Can be used to create a debug session.
"""
"""
...
...
tensorpack/tfutils/summary.py
View file @
505e28eb
...
@@ -5,10 +5,10 @@
...
@@ -5,10 +5,10 @@
import
re
import
re
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
import
six
import
six
import
tensorflow
as
tf
from
six.moves
import
range
from
six.moves
import
range
from
tensorflow.python.training
import
moving_averages
from
tensorflow.python.training
import
moving_averages
from
..compat
import
tfv1
as
tf
from
..utils
import
logger
from
..utils
import
logger
from
..utils.argtools
import
graph_memoized
from
..utils.argtools
import
graph_memoized
from
..utils.naming
import
MOVING_SUMMARY_OPS_KEY
from
..utils.naming
import
MOVING_SUMMARY_OPS_KEY
...
...
tensorpack/tfutils/symbolic_functions.py
View file @
505e28eb
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..compat
import
tfv1
from
..utils.develop
import
deprecated
from
..utils.develop
import
deprecated
__all__
=
[
'print_stat'
,
'rms'
]
__all__
=
[
'print_stat'
,
'rms'
]
...
@@ -30,7 +31,7 @@ def rms(x, name=None):
...
@@ -30,7 +31,7 @@ def rms(x, name=None):
"""
"""
if
name
is
None
:
if
name
is
None
:
name
=
x
.
op
.
name
+
'/rms'
name
=
x
.
op
.
name
+
'/rms'
with
tf
.
name_scope
(
None
):
# name already contains the scope
with
tf
v1
.
name_scope
(
None
):
# name already contains the scope
return
tf
.
sqrt
(
tf
.
reduce_mean
(
tf
.
square
(
x
)),
name
=
name
)
return
tf
.
sqrt
(
tf
.
reduce_mean
(
tf
.
square
(
x
)),
name
=
name
)
return
tf
.
sqrt
(
tf
.
reduce_mean
(
tf
.
square
(
x
)),
name
=
name
)
return
tf
.
sqrt
(
tf
.
reduce_mean
(
tf
.
square
(
x
)),
name
=
name
)
...
...
tensorpack/tfutils/tower.py
View file @
505e28eb
...
@@ -4,9 +4,10 @@
...
@@ -4,9 +4,10 @@
from
abc
import
ABCMeta
,
abstractmethod
,
abstractproperty
from
abc
import
ABCMeta
,
abstractmethod
,
abstractproperty
import
six
import
six
import
tensorflow
as
tf
from
six.moves
import
zip
from
six.moves
import
zip
from
..compat
import
tfv1
as
tf
from
..utils
import
logger
from
..utils
import
logger
from
..utils.argtools
import
call_only_once
from
..utils.argtools
import
call_only_once
from
..utils.develop
import
HIDE_DOC
from
..utils.develop
import
HIDE_DOC
...
...
tensorpack/tfutils/varreplace.py
View file @
505e28eb
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
# Credit: Qinyao He
# Credit: Qinyao He
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
import
tensorflow
as
tf
from
..compat
import
tfv1
as
tf
from
.common
import
get_tf_version_tuple
from
.common
import
get_tf_version_tuple
__all__
=
[
'custom_getter_scope'
,
'freeze_variables'
,
'remap_variables'
]
__all__
=
[
'custom_getter_scope'
,
'freeze_variables'
,
'remap_variables'
]
...
...
tensorpack/train/base.py
View file @
505e28eb
...
@@ -8,6 +8,7 @@ import six
...
@@ -8,6 +8,7 @@ import six
import
tensorflow
as
tf
import
tensorflow
as
tf
from
six.moves
import
range
from
six.moves
import
range
from
..compat
import
tfv1
from
..callbacks
import
Callback
,
Callbacks
,
Monitors
,
MonitorBase
from
..callbacks
import
Callback
,
Callbacks
,
Monitors
,
MonitorBase
from
..callbacks.steps
import
MaintainStepCounter
from
..callbacks.steps
import
MaintainStepCounter
from
..tfutils
import
get_global_step_value
from
..tfutils
import
get_global_step_value
...
@@ -222,7 +223,7 @@ class Trainer(object):
...
@@ -222,7 +223,7 @@ class Trainer(object):
session_creator (tf.train.SessionCreator):
session_creator (tf.train.SessionCreator):
session_init (sessinit.SessionInit):
session_init (sessinit.SessionInit):
"""
"""
assert
isinstance
(
session_creator
,
tf
.
train
.
SessionCreator
),
session_creator
assert
isinstance
(
session_creator
,
tf
v1
.
train
.
SessionCreator
),
session_creator
assert
isinstance
(
session_init
,
SessionInit
),
session_init
assert
isinstance
(
session_init
,
SessionInit
),
session_init
session_init
.
_setup_graph
()
session_init
.
_setup_graph
()
...
@@ -250,7 +251,7 @@ class Trainer(object):
...
@@ -250,7 +251,7 @@ class Trainer(object):
which can be useful when the training is not done by a single `train_op`.
which can be useful when the training is not done by a single `train_op`.
"""
"""
hooks
=
self
.
_callbacks
.
get_hooks
()
hooks
=
self
.
_callbacks
.
get_hooks
()
self
.
hooked_sess
=
tf
.
train
.
MonitoredSession
(
self
.
hooked_sess
=
tf
v1
.
train
.
MonitoredSession
(
session_creator
=
ReuseSessionCreator
(
self
.
sess
),
hooks
=
hooks
)
session_creator
=
ReuseSessionCreator
(
self
.
sess
),
hooks
=
hooks
)
@
call_only_once
@
call_only_once
...
...
tensorpack/train/interface.py
View file @
505e28eb
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
# File: interface.py
# File: interface.py
import
tensorflow
as
tf
from
..compat
import
tfv1
from
..input_source
import
DummyConstantInput
,
FeedfreeInput
,
FeedInput
,
InputSource
,
QueueInput
,
StagingInput
from
..input_source
import
DummyConstantInput
,
FeedfreeInput
,
FeedInput
,
InputSource
,
QueueInput
,
StagingInput
from
..utils
import
logger
from
..utils
import
logger
from
..compat
import
is_tfv2
from
.config
import
TrainConfig
from
.config
import
TrainConfig
from
.tower
import
SingleCostTrainer
from
.tower
import
SingleCostTrainer
from
.trainers
import
SimpleTrainer
from
.trainers
import
SimpleTrainer
...
@@ -71,6 +71,9 @@ def launch_train_with_config(config, trainer):
...
@@ -71,6 +71,9 @@ def launch_train_with_config(config, trainer):
launch_train_with_config(
launch_train_with_config(
config, SyncMultiGPUTrainerParameterServer(8, ps_device='gpu'))
config, SyncMultiGPUTrainerParameterServer(8, ps_device='gpu'))
"""
"""
if
is_tfv2
():
tfv1
.
disable_eager_execution
()
assert
isinstance
(
trainer
,
SingleCostTrainer
),
trainer
assert
isinstance
(
trainer
,
SingleCostTrainer
),
trainer
assert
isinstance
(
config
,
TrainConfig
),
config
assert
isinstance
(
config
,
TrainConfig
),
config
assert
config
.
model
is
not
None
assert
config
.
model
is
not
None
...
@@ -99,7 +102,7 @@ def launch_train_with_config(config, trainer):
...
@@ -99,7 +102,7 @@ def launch_train_with_config(config, trainer):
def
_check_unused_regularization
():
def
_check_unused_regularization
():
coll
=
tf
.
get_collection
(
tf
.
GraphKeys
.
REGULARIZATION_LOSSES
)
coll
=
tf
v1
.
get_collection
(
tfv1
.
GraphKeys
.
REGULARIZATION_LOSSES
)
unconsumed_reg
=
[]
unconsumed_reg
=
[]
for
c
in
coll
:
for
c
in
coll
:
if
len
(
c
.
consumers
())
==
0
:
if
len
(
c
.
consumers
())
==
0
:
...
...
tensorpack/train/tower.py
View file @
505e28eb
...
@@ -5,6 +5,7 @@ from abc import ABCMeta, abstractmethod
...
@@ -5,6 +5,7 @@ from abc import ABCMeta, abstractmethod
import
six
import
six
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..compat
import
tfv1
,
is_tfv2
from
..input_source
import
PlaceholderInput
from
..input_source
import
PlaceholderInput
from
..predict.base
import
OnlinePredictor
from
..predict.base
import
OnlinePredictor
from
..tfutils.gradproc
import
FilterNoneGrad
from
..tfutils.gradproc
import
FilterNoneGrad
...
@@ -126,7 +127,7 @@ class TowerTrainer(Trainer):
...
@@ -126,7 +127,7 @@ class TowerTrainer(Trainer):
input
.
setup
(
self
.
inputs_desc
)
input
.
setup
(
self
.
inputs_desc
)
vs_name
=
self
.
_vs_name_for_predictor
(
device_id
)
vs_name
=
self
.
_vs_name_for_predictor
(
device_id
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
),
\
with
tf
v1
.
variable_scope
(
tfv1
.
get_variable_scope
(),
reuse
=
True
),
\
tf
.
device
(
device
),
PredictTowerContext
(
tf
.
device
(
device
),
PredictTowerContext
(
tower_name
,
vs_name
=
vs_name
):
tower_name
,
vs_name
=
vs_name
):
logger
.
info
(
"Building graph for predict tower '{}' on device {} {}..."
.
format
(
logger
.
info
(
"Building graph for predict tower '{}' on device {} {}..."
.
format
(
...
@@ -254,10 +255,14 @@ class SingleCostTrainer(TowerTrainer):
...
@@ -254,10 +255,14 @@ class SingleCostTrainer(TowerTrainer):
return
None
# this is the tower function, could be called for inference
return
None
# this is the tower function, could be called for inference
if
ctx
.
has_own_variables
:
if
ctx
.
has_own_variables
:
varlist
=
ctx
.
get_collection_in_tower
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
varlist
=
ctx
.
get_collection_in_tower
(
tf
v1
.
GraphKeys
.
TRAINABLE_VARIABLES
)
else
:
else
:
varlist
=
tf
.
trainable_variables
()
varlist
=
tf
v1
.
trainable_variables
()
opt
=
get_opt_fn
()
opt
=
get_opt_fn
()
if
is_tfv2
()
and
isinstance
(
opt
,
tf
.
optimizers
.
Optimizer
):
grads
=
opt
.
get_gradients
(
cost
,
varlist
)
grads
=
list
(
zip
(
grads
,
varlist
))
else
:
grads
=
opt
.
compute_gradients
(
grads
=
opt
.
compute_gradients
(
cost
,
var_list
=
varlist
,
cost
,
var_list
=
varlist
,
gate_gradients
=
self
.
GATE_GRADIENTS
,
gate_gradients
=
self
.
GATE_GRADIENTS
,
...
...
tensorpack/utils/argtools.py
View file @
505e28eb
...
@@ -52,7 +52,7 @@ def graph_memoized(func):
...
@@ -52,7 +52,7 @@ def graph_memoized(func):
"""
"""
# TODO it keeps the graph alive
# TODO it keeps the graph alive
import
tensorflow
as
tf
from
..compat
import
tfv1
GRAPH_ARG_NAME
=
'__IMPOSSIBLE_NAME_FOR_YOU__'
GRAPH_ARG_NAME
=
'__IMPOSSIBLE_NAME_FOR_YOU__'
@
memoized
@
memoized
...
@@ -63,7 +63,7 @@ def graph_memoized(func):
...
@@ -63,7 +63,7 @@ def graph_memoized(func):
@
functools
.
wraps
(
func
)
@
functools
.
wraps
(
func
)
def
wrapper
(
*
args
,
**
kwargs
):
def
wrapper
(
*
args
,
**
kwargs
):
assert
GRAPH_ARG_NAME
not
in
kwargs
,
"No Way!!"
assert
GRAPH_ARG_NAME
not
in
kwargs
,
"No Way!!"
graph
=
tf
.
get_default_graph
()
graph
=
tf
v1
.
get_default_graph
()
kwargs
[
GRAPH_ARG_NAME
]
=
graph
kwargs
[
GRAPH_ARG_NAME
]
=
graph
return
func_with_graph_arg
(
*
args
,
**
kwargs
)
return
func_with_graph_arg
(
*
args
,
**
kwargs
)
return
wrapper
return
wrapper
...
...
tox.ini
View file @
505e28eb
...
@@ -5,7 +5,7 @@ ignore = E265,E741,E742,E743,W504,W605
...
@@ -5,7 +5,7 @@ ignore = E265,E741,E742,E743,W504,W605
exclude
=
.git,
exclude
=
.git,
__init__.py,
__init__.py,
setup.py,
setup.py,
tensorpack/
train/eager.py
,
tensorpack/
compat/*
,
docs,
docs,
examples,
examples,
docs/conf.py
docs/conf.py
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment