Commit 505e28eb authored by Yuxin Wu

Backport TensorSpec; tf=tf.compat.v1 in many files.

parent e4941595
@@ -16,9 +16,15 @@ If you think:
 Then it is a good time to open an issue.
 
-## How to print/dump intermediate results in training
+## How to print/dump intermediate results during training
 
-1. Learn `tf.Print`.
+1. Learn `tf.Print`. Most of the time, adding one line in between:
+   ```python
+   tensor = obtain_a_tensor()
+   tensor = tf.Print(tensor, [tf.shape(tensor), tensor], tensor.name, summarize=100)
+   use_the_tensor(tensor)
+   ```
+   is sufficient.
 2. Know [DumpTensors](../modules/callbacks.html#tensorpack.callbacks.DumpTensors),
    [ProcessTensors](../modules/callbacks.html#tensorpack.callbacks.ProcessTensors) callbacks.
......
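For item 2, a minimal usage sketch (the tensor name `'conv1/output'` is illustrative, and `DumpTensors` is assumed to take a list of tensor names as in the linked docs):

```python
from tensorpack.callbacks import DumpTensors

# Fetches the named tensors every step and dumps them to the logger directory.
callbacks = [DumpTensors(['conv1/output'])]
```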
@@ -21,8 +21,8 @@ class Model(ModelDesc):
         """
         Define all the inputs (with type, shape, name) that the graph will need.
         """
-        return [tf.placeholder(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
-                tf.placeholder(tf.int32, (None,), 'label')]
+        return [tf.TensorSpec((None, IMAGE_SIZE, IMAGE_SIZE), tf.float32, 'input'),
+                tf.TensorSpec((None,), tf.int32, 'label')]
 
     def build_graph(self, image, label):
         """This function should build the model which takes the input variables
@@ -51,7 +51,7 @@ class Model(ModelDesc):
         cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')  # the average cross-entropy loss
 
-        correct = tf.cast(tf.nn.in_top_k(logits, label, 1), tf.float32, name='correct')
+        correct = tf.cast(tf.nn.in_top_k(predictions=logits, targets=label, k=1), tf.float32, name='correct')
         accuracy = tf.reduce_mean(correct, name='accuracy')
 
         # This will monitor training error & accuracy (in a moving average fashion). The value will be automatically
......
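Note the argument-order change in the hunk above: `tf.placeholder(dtype, shape, name)` becomes `tf.TensorSpec(shape, dtype, name)`, with shape first. A hedged before/after sketch, assuming a TF 1.x build where both names exist:

```python
import tensorflow as tf

IMAGE_SIZE = 28  # the constant from the surrounding example

# old style: placeholder takes dtype first and creates a graph node
old = tf.placeholder(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input')

# new style: TensorSpec takes shape first and merely *describes* the input;
# tensorpack builds the actual placeholder from the spec later
new = tf.TensorSpec((None, IMAGE_SIZE, IMAGE_SIZE), tf.float32, 'input')
```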
@@ -4,7 +4,7 @@
 from abc import ABCMeta
 import six
-import tensorflow as tf
+from ..compat import tfv1 as tf
 
 from ..tfutils.common import get_op_or_tensor_by_name
......
@@ -6,9 +6,9 @@
 import numpy as np
 import os
-import tensorflow as tf
 from six.moves import zip
 
+from ..compat import tfv1 as tf
 from ..tfutils.common import get_op_tensor_name
 from ..utils import logger
 from .base import Callback
......
@@ -6,7 +6,7 @@ import traceback
 from contextlib import contextmanager
 from time import time as timer
 import six
-import tensorflow as tf
+from ..compat import tfv1 as tf
 
 from ..utils import logger
 from ..utils.utils import humanize_time_delta
......
@@ -6,7 +6,7 @@
 import tensorflow as tf
 
-from ..tfutils.common import tfv1
+from ..compat import tfv1
 from ..utils.develop import HIDE_DOC
 from .base import Callback
......
@@ -5,15 +5,14 @@
 import itertools
 import sys
 from contextlib import contextmanager
-import tensorflow as tf
 import tqdm
 from six.moves import range
 from tensorflow.python.training.monitored_session import _HookedSession as HookedSession
 
+from ..compat import tfv1 as tf
 from ..dataflow.base import DataFlow
 from ..input_source import FeedInput, InputSource, QueueInput, StagingInput
 from ..tfutils.tower import PredictTowerContext
-from ..tfutils.common import tfv1
 from ..utils import logger
 from ..utils.utils import get_tqdm_kwargs
 from .base import Callback
@@ -28,7 +27,7 @@ def _device_from_int(dev):
     return '/gpu:{}'.format(dev) if dev >= 0 else '/cpu:0'
 
-class InferencerToHook(tfv1.train.SessionRunHook):
+class InferencerToHook(tf.train.SessionRunHook):
     def __init__(self, inf, fetches):
         self._inf = inf
         self._fetches = fetches
......
@@ -12,8 +12,8 @@ import time
 from collections import defaultdict
 from datetime import datetime
 import six
-import tensorflow as tf
 
+from ..compat import tfv1 as tf
 from ..libinfo import __git_version__
 from ..tfutils.summary import create_image_summary, create_scalar_summary
 from ..utils import logger
......
@@ -4,8 +4,8 @@
 import os
 from datetime import datetime
-import tensorflow as tf
 
+from ..compat import tfv1 as tf
 from ..utils import logger
 from .base import Callback
@@ -40,8 +40,8 @@ class ModelSaver(Callback):
         if checkpoint_dir is None:
             checkpoint_dir = logger.get_logger_dir()
         if checkpoint_dir is not None:
-            if not tf.gfile.IsDirectory(checkpoint_dir):
-                tf.gfile.MakeDirs(checkpoint_dir)
+            if not tf.gfile.IsDirectory(checkpoint_dir):  # v2: tf.io.gfile.isdir
+                tf.gfile.MakeDirs(checkpoint_dir)  # v2: tf.io.gfile.makedirs
         self.checkpoint_dir = checkpoint_dir
 
     def _setup_graph(self):
......
@@ -3,10 +3,10 @@
 """ Some common step callbacks. """
 
-import tensorflow as tf
 import tqdm
 from six.moves import zip
 
+from ..compat import tfv1 as tf
 from ..tfutils.common import get_global_step_var, get_op_tensor_name
 from ..utils import logger
 from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
......
@@ -4,8 +4,8 @@
 import numpy as np
 from collections import deque
-import tensorflow as tf
 
+from ..compat import tfv1 as tf
 from ..tfutils.common import get_op_tensor_name
 from ..utils import logger
 from ..utils.naming import MOVING_SUMMARY_OPS_KEY
......
#!/usr/bin/env python

import tensorflow as tf


def backport_tensor_spec():
    if hasattr(tf, 'TensorSpec'):
        return tf.TensorSpec
    try:
        # available since 1.7
        from tensorflow.python.framework.tensor_spec import TensorSpec
    except ImportError:
        pass
    else:
        tf.TensorSpec = TensorSpec
        return TensorSpec
    from .tensor_spec import TensorSpec
    tf.TensorSpec = TensorSpec
    return TensorSpec


def is_tfv2():
    try:
        from tensorflow.python import tf2
        return tf2.enabled()
    except Exception:
        return False


if is_tfv2():
    tfv1 = tf.compat.v1
    if not hasattr(tf, 'layers'):
        # promised at https://github.com/tensorflow/community/pull/24#issuecomment-440453886
        tf.layers = tf.keras.layers
else:
    tfv1 = tf
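A sketch of how the rest of this commit consumes the new module (the `tensorpack.compat` import path is an assumption, inferred from the flake8 exclude at the end of the diff):

```python
from tensorpack.compat import tfv1 as tf, backport_tensor_spec, is_tfv2

# tf.TensorSpec (in the TF internals since 1.7), or the local copy below
TensorSpec = backport_tensor_spec()

if is_tfv2():
    tf.disable_eager_execution()  # the v1-style graph code assumes eager is off
```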
"""
Copied from tensorflow/python/framework/tensor_spec.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
class TensorSpec(object):
"""Describes a tf.Tensor.
Metadata for describing the `tf.Tensor` objects accepted or returned
by some TensorFlow APIs.
"""
__slots__ = ["_shape", "_shape_tuple", "_dtype", "_name"]
def __init__(self, shape, dtype=dtypes.float32, name=None):
"""Creates a TensorSpec.
Args:
shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
dtype: Value convertible to `tf.DType`. The type of the tensor values.
name: Optional name for the Tensor.
Raises:
TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is
not convertible to a `tf.DType`.
"""
self._shape = tensor_shape.TensorShape(shape)
try:
self._shape_tuple = tuple(self.shape.as_list())
except ValueError:
self._shape_tuple = None
self._dtype = dtypes.as_dtype(dtype)
self._name = name
@classmethod
def from_spec(cls, spec, name=None):
return cls(spec.shape, spec.dtype, name or spec.name)
@classmethod
def from_tensor(cls, tensor, name=None):
if isinstance(tensor, ops.EagerTensor):
return TensorSpec(tensor.shape, tensor.dtype, name)
elif isinstance(tensor, ops.Tensor):
return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name)
else:
raise ValueError("`tensor` should be a tf.Tensor")
@property
def shape(self):
"""Returns the `TensorShape` that represents the shape of the tensor."""
return self._shape
@property
def dtype(self):
"""Returns the `dtype` of elements in the tensor."""
return self._dtype
@property
def name(self):
"""Returns the (optionally provided) name of the described tensor."""
return self._name
def is_compatible_with(self, spec_or_tensor):
"""Returns True if spec_or_tensor is compatible with this TensorSpec.
Two tensors are considered compatible if they have the same dtype
and their shapes are compatible (see `tf.TensorShape.is_compatible_with`).
Args:
spec_or_tensor: A tf.TensorSpec or a tf.Tensor
Returns:
True if spec_or_tensor is compatible with self.
"""
return (self._dtype.is_compatible_with(spec_or_tensor.dtype) and
self._shape.is_compatible_with(spec_or_tensor.shape))
def __repr__(self):
return "TensorSpec(shape={}, dtype={}, name={})".format(
self.shape, repr(self.dtype), repr(self.name))
def __hash__(self):
return hash((self._shape_tuple, self.dtype))
def __eq__(self, other):
return (self._shape_tuple == other._shape_tuple # pylint: disable=protected-access
and self.dtype == other.dtype
and self._name == other._name) # pylint: disable=protected-access
def __ne__(self, other):
return not self == other
def __reduce__(self):
return TensorSpec, (self._shape, self._dtype, self._name)
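A short sketch of what the copied class provides (outputs in comments are approximate and TF-version dependent):

```python
from tensorflow.python.framework import dtypes

spec = TensorSpec((None, 28, 28), dtypes.float32, name='input')

# Compatibility means equal dtypes plus compatible shapes;
# an unknown (None) dimension matches any concrete size.
print(spec.is_compatible_with(TensorSpec((32, 28, 28), dtypes.float32)))  # True
print(spec.is_compatible_with(TensorSpec((32, 28, 28), dtypes.int32)))    # False
print(spec)  # TensorSpec(shape=(?, 28, 28), dtype=tf.float32, name='input')
```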
@@ -7,13 +7,12 @@ import tensorflow as tf
 from ..models.regularize import regularize_cost_from_collection
 from ..tfutils.tower import get_current_tower_context
-from ..tfutils.common import get_tf_version_tuple
 from ..utils import logger
 from ..utils.argtools import memoized_method
 from ..utils.develop import log_deprecated
+from ..compat import backport_tensor_spec, tfv1
 
-if get_tf_version_tuple() >= (1, 7):
-    from tensorflow.python.framework.tensor_spec import TensorSpec
+TensorSpec = backport_tensor_spec()
 
 __all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']
@@ -49,8 +48,8 @@ class InputDesc(
         Returns:
             tf.Tensor:
         """
-        with tf.name_scope(None):   # clear any name scope it might get called in
-            ret = tf.placeholder(
+        with tfv1.name_scope(None):   # clear any name scope it might get called in
+            ret = tfv1.placeholder(
                 self.type, shape=self.shape, name=self.name)
             self._register_cached_placeholder(ret)
         return ret
@@ -63,7 +62,7 @@ class InputDesc(
         Returns:
             tf.Tensor:
         """
-        g = tf.get_default_graph()
+        g = tfv1.get_default_graph()
         if g in self._cached_placeholder:
             return self._cached_placeholder[g]
         else:
......
@@ -8,6 +8,7 @@ from itertools import chain
 import tensorflow as tf
 from six.moves import range, zip
 
+from ..compat import tfv1
 from ..callbacks.base import Callback, CallbackFactory
 from ..callbacks.graph import RunOp
 from ..dataflow import DataFlow, MapData, RepeatedData
@@ -84,7 +85,7 @@ class FeedInput(InputSource):
             dp = next(self._itr)
             assert len(dp) == len(self._placeholders), "[FeedInput] datapoints and inputs are of different length!"
             feed = _make_feeds(self._placeholders, dp)
-            return tf.train.SessionRunArgs(fetches=[], feed_dict=feed)
+            return tfv1.train.SessionRunArgs(fetches=[], feed_dict=feed)
 
         def _reset(self):
             self._itr = self._ds.__iter__()
@@ -228,9 +229,9 @@ class QueueInput(FeedfreeInput):
         """
         self.thread.pause()  # pause enqueue
 
-        opt = tf.RunOptions()
+        opt = tfv1.RunOptions()
         opt.timeout_in_ms = 2000  # 2s
-        sess = tf.get_default_session()
+        sess = tfv1.get_default_session()
         # dequeue until empty
         try:
             while True:
@@ -304,7 +305,7 @@ class BatchQueueInput(QueueInput):
         # prepare placeholders without the first dimension
         placehdrs_nobatch = []
        for p in self.input_placehdrs:
-            placehdrs_nobatch.append(tf.placeholder(
+            placehdrs_nobatch.append(tfv1.placeholder(
                 dtype=p.dtype, shape=p.get_shape().as_list()[1:],
                 name=get_op_tensor_name(p.name)[0] + '-nobatch'))
@@ -546,7 +547,7 @@ class StagingInput(FeedfreeInput):
             unstage_ops = self._input._get_unstage_ops()
             unstage_op = tf.group(*unstage_ops, name='unstage_all')
             self._check_dependency_op = unstage_ops[0]
-            self.fetches = tf.train.SessionRunArgs(
+            self.fetches = tfv1.train.SessionRunArgs(
                 fetches=[self.stage_op, unstage_op])
 
     def _prefill(self, sess):
......
@@ -52,7 +52,7 @@ os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '0'
 try:
     import tensorflow as tf  # noqa
     _version = tf.__version__.split('.')
-    assert int(_version[0]) >= 1 and int(_version[1]) >= 3, "TF>=1.3 is required!"
+    assert (int(_version[0]), int(_version[1])) >= (1, 3), "TF>=1.3 is required!"
     _HAS_TF = True
 except ImportError:
     print("Failed to import tensorflow.")
......
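The rewritten assertion fixes a real bug: the old per-component check rejects any 2.x release, because the minor version 0 fails `>= 3`. Tuple comparison is lexicographic, so it behaves correctly:

```python
def old_check(major, minor):
    return major >= 1 and minor >= 3

def new_check(major, minor):
    return (major, minor) >= (1, 3)

print(old_check(2, 0))  # False -- wrongly rejects TF 2.0
print(new_check(2, 0))  # True  -- lexicographic comparison: 2 > 1 decides
print(new_check(1, 2))  # False -- TF 1.2 is still correctly rejected
```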
@@ -4,7 +4,7 @@
 import re
 import six
-import tensorflow as tf
+from ..compat import tfv1 as tf  # this should be avoided first in model code
 from tensorflow.python.training import moving_averages
 
 from ..tfutils.collection import backup_collection, restore_collection
......
@@ -2,7 +2,7 @@
 # File: conv2d.py
 
-import tensorflow as tf
+from ..compat import tfv1 as tf  # this should be avoided first in model code
 
 from ..tfutils.common import get_tf_version_tuple
 from ..utils.argtools import get_data_format, shape2d, shape4d, log_once
......
@@ -3,7 +3,7 @@
 import numpy as np
-import tensorflow as tf
+from ..compat import tfv1 as tf  # this should be avoided first in model code
 
 from ..tfutils.common import get_tf_version_tuple
 from .common import VariableHolder, layer_register
......
@@ -2,7 +2,7 @@
 # File: layer_norm.py
 
-import tensorflow as tf
+from ..compat import tfv1 as tf  # this should be avoided first in model code
 
 from ..utils.argtools import get_data_format
 from .common import VariableHolder, layer_register
......
@@ -2,7 +2,7 @@
 # File: pool.py
 
 import numpy as np
-import tensorflow as tf
+from ..compat import tfv1 as tf  # this should be avoided first in model code
 
 from ..utils.argtools import get_data_format, shape2d
 from ..utils.develop import log_deprecated
......
@@ -8,6 +8,7 @@ from functools import wraps
 import six
 import tensorflow as tf
 
+from ..compat import tfv1
 from ..tfutils.argscope import get_arg_scope
 from ..tfutils.model_utils import get_shape_str
 from ..utils import logger
@@ -117,7 +118,7 @@ def layer_register(
             #     del actual_args[k]
 
             if name is not None:  # use scope
-                with tf.variable_scope(name) as scope:
+                with tfv1.variable_scope(name) as scope:
                     # this name is only used to suppress logging, doesn't hurt to do some heuristics
                     scope_name = re.sub('tower[0-9]+/', '', scope.name)
                     do_log_shape = log_shape and scope_name not in _LAYER_LOGGED
......
@@ -5,6 +5,7 @@
 import re
 import tensorflow as tf
 
+from ..compat import tfv1
 from ..tfutils.common import get_tf_version_tuple
 from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
@@ -60,13 +61,13 @@ def regularize_cost(regex, func, name='regularize_cost'):
     # If vars are shared, regularize all of them
     # If vars are replicated, only regularize those in the current tower
     if ctx.has_own_variables:
-        params = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
+        params = ctx.get_collection_in_tower(tfv1.GraphKeys.TRAINABLE_VARIABLES)
     else:
-        params = tf.trainable_variables()
+        params = tfv1.trainable_variables()
 
     names = []
 
-    with tf.name_scope(name + '_internals'):
+    with tfv1.name_scope(name + '_internals'):
         costs = []
         for p in params:
             para_name = p.op.name
@@ -119,9 +120,9 @@ def regularize_cost_from_collection(name='regularize_cost'):
     # NOTE: this collection doesn't always grow with towers.
     # It only grows with actual variable creation, but not get_variable call.
     if ctx.has_own_variables:  # be careful of the first tower (name='')
-        losses = ctx.get_collection_in_tower(tf.GraphKeys.REGULARIZATION_LOSSES)
+        losses = ctx.get_collection_in_tower(tfv1.GraphKeys.REGULARIZATION_LOSSES)
     else:
-        losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+        losses = tfv1.get_collection(tfv1.GraphKeys.REGULARIZATION_LOSSES)
     if len(losses) > 0:
         logger.info("regularize_cost_from_collection() found {} regularizers "
                     "in REGULARIZATION_LOSSES collection.".format(len(losses)))
......
@@ -5,7 +5,8 @@
 from contextlib import contextmanager
 from copy import copy
 import six
-import tensorflow as tf
+
+from ..compat import tfv1 as tf
 from ..utils import logger
 from ..utils.argtools import memoized
......
@@ -5,12 +5,13 @@
 import tensorflow as tf
 from six.moves import map
 
+from ..compat import tfv1
 from ..utils.argtools import graph_memoized
 
 __all__ = ['get_default_sess_config',
            'get_global_step_value',
            'get_global_step_var',
-           'get_tf_version_tuple'
+           'get_tf_version_tuple',
            # 'get_op_tensor_name',
            # 'get_tensors_by_names',
            # 'get_op_or_tensor_by_name',
@@ -30,7 +31,7 @@ def get_default_sess_config(mem_fraction=0.99):
     Returns:
         tf.ConfigProto: the config to use.
     """
-    conf = tf.ConfigProto()
+    conf = tfv1.ConfigProto()
 
     conf.allow_soft_placement = True
     # conf.log_device_placement = True
@@ -64,9 +65,9 @@ def get_global_step_var():
     Returns:
         tf.Tensor: the global_step variable in the current graph. Create if doesn't exist.
     """
-    scope = tf.VariableScope(reuse=False, name='')  # the root vs
-    with tf.variable_scope(scope):
-        var = tf.train.get_or_create_global_step()
+    scope = tfv1.VariableScope(reuse=False, name='')  # the root vs
+    with tfv1.variable_scope(scope):
+        var = tfv1.train.get_or_create_global_step()
     return var
@@ -78,8 +79,8 @@ def get_global_step_value():
     Has to be called under a default session.
     """
-    return tf.train.global_step(
-        tf.get_default_session(),
+    return tfv1.train.global_step(
+        tfv1.get_default_session(),
         get_global_step_var())
@@ -108,7 +109,7 @@ def get_tensors_by_names(names):
         names (list):
     """
     ret = []
-    G = tf.get_default_graph()
+    G = tfv1.get_default_graph()
     for n in names:
         opn, varn = get_op_tensor_name(n)
         ret.append(G.get_tensor_by_name(varn))
@@ -125,7 +126,7 @@ def get_op_or_tensor_by_name(name):
     Raises:
         KeyError, if the name doesn't exist
     """
-    G = tf.get_default_graph()
+    G = tfv1.get_default_graph()
 
     def f(n):
         if len(n) >= 3 and n[-2] == ':':
@@ -140,7 +141,7 @@
 def gpu_available_in_session():
-    sess = tf.get_default_session()
+    sess = tfv1.get_default_session()
     for dev in sess.list_devices():
         if dev.device_type.lower() == 'gpu':
             return True
@@ -152,17 +153,3 @@
     Return TensorFlow version as a 2-element tuple (for comparison).
     """
     return tuple(map(int, tf.__version__.split('.')[:2]))
-
-
-def is_tf2():
-    try:
-        from tensorflow.python import tf2
-        return tf2.enabled()
-    except Exception:
-        return False
-
-
-if is_tf2():
-    tfv1 = tf.compat.v1
-else:
-    tfv1 = tf
......
 import tensorflow as tf
-from tensorflow.contrib.graph_editor import get_backward_walk_ops
 
 from ..utils.argtools import graph_memoized
@@ -33,6 +32,7 @@ def dependency_of_targets(targets, op):
         op = op.op
     assert isinstance(op, tf.Operation), op
 
+    from tensorflow.contrib.graph_editor import get_backward_walk_ops
     # alternative implementation can use graph_util.extract_sub_graph
     dependent_ops = get_backward_walk_ops(targets, control_inputs=True)
     return op in dependent_ops
......
@@ -2,7 +2,7 @@
 # File: model_utils.py
 # Author: tensorpack contributors
 
-import tensorflow as tf
+from ..compat import tfv1 as tf
 from tabulate import tabulate
 from termcolor import colored
......
@@ -5,7 +5,8 @@
 from contextlib import contextmanager
 import tensorflow as tf
 
-from ..tfutils.common import get_tf_version_tuple, tfv1
+from ..tfutils.common import get_tf_version_tuple
+from ..compat import tfv1
 from ..utils.develop import HIDE_DOC
 from .gradproc import FilterNoneGrad, GradientProcessor
......
@@ -4,8 +4,8 @@
 import functools
 from contextlib import contextmanager
-import tensorflow as tf
 
+from ..compat import tfv1 as tf
 from ..utils.argtools import graph_memoized
 from .common import get_tf_version_tuple
......
@@ -2,10 +2,7 @@
 # File: sesscreate.py
 
-import tensorflow as tf
-from tensorflow.contrib.graph_editor import get_backward_walk_ops
-
-from ..tfutils.common import tfv1
+from ..compat import tfv1 as tf, is_tfv2
 from ..utils import logger
 from .common import get_default_sess_config
@@ -20,7 +17,7 @@ A SessionCreator should:
 """
 
-class NewSessionCreator(tfv1.train.SessionCreator):
+class NewSessionCreator(tf.train.SessionCreator):
     def __init__(self, target='', config=None):
         """
         Args:
@@ -59,11 +56,15 @@ bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
             return False
 
         def run(op):
-            deps = get_backward_walk_ops(op, control_inputs=True)
-            for dep_op in deps:
-                if blocking_op(dep_op):
-                    logger.warn(
-                        "Initializer '{}' depends on a blocking op '{}'. This initializer is likely to hang!".format(
-                            op.name, dep_op.name))
+            if not is_tfv2():
+                from tensorflow.contrib.graph_editor import get_backward_walk_ops
+                deps = get_backward_walk_ops(op, control_inputs=True)
+                for dep_op in deps:
+                    if blocking_op(dep_op):
+                        logger.warn(
+                            "Initializer '{}' depends on a blocking op '{}'. "
+                            "This initializer is likely to hang!".format(
+                                op.name, dep_op.name))
             sess.run(op)
@@ -73,7 +74,7 @@ bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
         return sess
 
-class ReuseSessionCreator(tfv1.train.SessionCreator):
+class ReuseSessionCreator(tf.train.SessionCreator):
     """
     Returns an existing session.
     """
@@ -88,7 +89,7 @@ class ReuseSessionCreator(tfv1.train.SessionCreator):
         return self.sess
 
-class SessionCreatorAdapter(tfv1.train.SessionCreator):
+class SessionCreatorAdapter(tf.train.SessionCreator):
     """
    Apply a function on the output of a SessionCreator. Can be used to create a debug session.
     """
......
@@ -5,10 +5,10 @@
 import re
 from contextlib import contextmanager
 import six
-import tensorflow as tf
 from six.moves import range
 from tensorflow.python.training import moving_averages
 
+from ..compat import tfv1 as tf
 from ..utils import logger
 from ..utils.argtools import graph_memoized
 from ..utils.naming import MOVING_SUMMARY_OPS_KEY
......
@@ -4,6 +4,7 @@
 import tensorflow as tf
 
+from ..compat import tfv1
 from ..utils.develop import deprecated
 
 __all__ = ['print_stat', 'rms']
@@ -30,7 +31,7 @@ def rms(x, name=None):
     """
     if name is None:
         name = x.op.name + '/rms'
-        with tf.name_scope(None):  # name already contains the scope
+        with tfv1.name_scope(None):  # name already contains the scope
             return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
     return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
......
@@ -4,9 +4,10 @@
 from abc import ABCMeta, abstractmethod, abstractproperty
 import six
-import tensorflow as tf
 from six.moves import zip
 
+from ..compat import tfv1 as tf
 from ..utils import logger
 from ..utils.argtools import call_only_once
 from ..utils.develop import HIDE_DOC
......
@@ -3,8 +3,8 @@
 # Credit: Qinyao He
 from contextlib import contextmanager
-import tensorflow as tf
 
+from ..compat import tfv1 as tf
 from .common import get_tf_version_tuple
 
 __all__ = ['custom_getter_scope', 'freeze_variables', 'remap_variables']
......
@@ -8,6 +8,7 @@ import six
 import tensorflow as tf
 from six.moves import range
 
+from ..compat import tfv1
 from ..callbacks import Callback, Callbacks, Monitors, MonitorBase
 from ..callbacks.steps import MaintainStepCounter
 from ..tfutils import get_global_step_value
@@ -222,7 +223,7 @@ class Trainer(object):
             session_creator (tf.train.SessionCreator):
             session_init (sessinit.SessionInit):
         """
-        assert isinstance(session_creator, tf.train.SessionCreator), session_creator
+        assert isinstance(session_creator, tfv1.train.SessionCreator), session_creator
         assert isinstance(session_init, SessionInit), session_init
         session_init._setup_graph()
@@ -250,7 +251,7 @@ class Trainer(object):
         which can be useful when the training is not done by a single `train_op`.
         """
         hooks = self._callbacks.get_hooks()
-        self.hooked_sess = tf.train.MonitoredSession(
+        self.hooked_sess = tfv1.train.MonitoredSession(
             session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
 
     @call_only_once
......
 # -*- coding: utf-8 -*-
 # File: interface.py
 
-import tensorflow as tf
+from ..compat import tfv1
 
 from ..input_source import DummyConstantInput, FeedfreeInput, FeedInput, InputSource, QueueInput, StagingInput
 from ..utils import logger
+from ..compat import is_tfv2
 from .config import TrainConfig
 from .tower import SingleCostTrainer
 from .trainers import SimpleTrainer
@@ -71,6 +71,9 @@ def launch_train_with_config(config, trainer):
         launch_train_with_config(
             config, SyncMultiGPUTrainerParameterServer(8, ps_device='gpu'))
     """
+    if is_tfv2():
+        tfv1.disable_eager_execution()
+
     assert isinstance(trainer, SingleCostTrainer), trainer
     assert isinstance(config, TrainConfig), config
     assert config.model is not None
@@ -99,7 +102,7 @@ def launch_train_with_config(config, trainer):
 def _check_unused_regularization():
-    coll = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+    coll = tfv1.get_collection(tfv1.GraphKeys.REGULARIZATION_LOSSES)
     unconsumed_reg = []
     for c in coll:
         if len(c.consumers()) == 0:
......
@@ -5,6 +5,7 @@ from abc import ABCMeta, abstractmethod
 import six
 import tensorflow as tf
 
+from ..compat import tfv1, is_tfv2
 from ..input_source import PlaceholderInput
 from ..predict.base import OnlinePredictor
 from ..tfutils.gradproc import FilterNoneGrad
@@ -126,7 +127,7 @@ class TowerTrainer(Trainer):
         input.setup(self.inputs_desc)
         vs_name = self._vs_name_for_predictor(device_id)
-        with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
+        with tfv1.variable_scope(tfv1.get_variable_scope(), reuse=True), \
                 tf.device(device), PredictTowerContext(
                     tower_name, vs_name=vs_name):
             logger.info("Building graph for predict tower '{}' on device {} {}...".format(
@@ -254,10 +255,14 @@ class SingleCostTrainer(TowerTrainer):
             return None  # this is the tower function, could be called for inference
 
         if ctx.has_own_variables:
-            varlist = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
+            varlist = ctx.get_collection_in_tower(tfv1.GraphKeys.TRAINABLE_VARIABLES)
         else:
-            varlist = tf.trainable_variables()
+            varlist = tfv1.trainable_variables()
         opt = get_opt_fn()
-        grads = opt.compute_gradients(
-            cost, var_list=varlist,
-            gate_gradients=self.GATE_GRADIENTS,
+        if is_tfv2() and isinstance(opt, tf.optimizers.Optimizer):
+            grads = opt.get_gradients(cost, varlist)
+            grads = list(zip(grads, varlist))
+        else:
+            grads = opt.compute_gradients(
+                cost, var_list=varlist,
+                gate_gradients=self.GATE_GRADIENTS,
......
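The last hunk above handles the v1/v2 optimizer split; a condensed sketch of just that branch (`compute_grads` is an illustrative name, and the real code also passes `gate_gradients` and other options):

```python
import tensorflow as tf
from tensorpack.compat import is_tfv2

def compute_grads(opt, cost, varlist):
    # Keras/v2 optimizers have no compute_gradients(); get_gradients()
    # returns bare gradients, so pair them with the variables manually.
    if is_tfv2() and isinstance(opt, tf.optimizers.Optimizer):
        grads = opt.get_gradients(cost, varlist)
        return list(zip(grads, varlist))
    # v1 optimizers return (gradient, variable) pairs directly.
    return opt.compute_gradients(cost, var_list=varlist)
```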
@@ -52,7 +52,7 @@ def graph_memoized(func):
     """
     # TODO it keeps the graph alive
-    import tensorflow as tf
+    from ..compat import tfv1
     GRAPH_ARG_NAME = '__IMPOSSIBLE_NAME_FOR_YOU__'
 
     @memoized
@@ -63,7 +63,7 @@ def graph_memoized(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         assert GRAPH_ARG_NAME not in kwargs, "No Way!!"
-        graph = tf.get_default_graph()
+        graph = tfv1.get_default_graph()
         kwargs[GRAPH_ARG_NAME] = graph
         return func_with_graph_arg(*args, **kwargs)
     return wrapper
......
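`graph_memoized` caches one result per default graph by smuggling the graph in as a hidden keyword argument; a small illustrative sketch of the pattern (`get_keep_prob` is hypothetical):

```python
from tensorpack.utils.argtools import graph_memoized

@graph_memoized
def get_keep_prob():
    import tensorflow as tf
    # created once per default graph; later calls in the same graph get the cache
    return tf.placeholder_with_default(1.0, shape=(), name='keep_prob')
```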
@@ -5,7 +5,7 @@ ignore = E265,E741,E742,E743,W504,W605
 exclude = .git,
           __init__.py,
           setup.py,
-          tensorpack/train/eager.py,
+          tensorpack/compat/*,
           docs,
           examples,
           docs/conf.py
......