Commit 2ce43d70 authored by Yuxin Wu's avatar Yuxin Wu

make code importable under tf2

parent 7b4980c9
...@@ -80,11 +80,13 @@ class RunUpdateOps(RunOp): ...@@ -80,11 +80,13 @@ class RunUpdateOps(RunOp):
each `sess.run` call. each `sess.run` call.
""" """
def __init__(self, collection=tf.GraphKeys.UPDATE_OPS): def __init__(self, collection=None):
""" """
Args: Args:
collection (str): collection of ops to run. Defaults to ``tf.GraphKeys.UPDATE_OPS`` collection (str): collection of ops to run. Defaults to ``tf.GraphKeys.UPDATE_OPS``
""" """
if collection is None:
collection = tf.GraphKeys.UPDATE_OPS
name = 'UPDATE_OPS' if collection == tf.GraphKeys.UPDATE_OPS else collection name = 'UPDATE_OPS' if collection == tf.GraphKeys.UPDATE_OPS else collection
def f(): def f():
......
...@@ -6,12 +6,13 @@ ...@@ -6,12 +6,13 @@
import tensorflow as tf import tensorflow as tf
from ..tfutils.common import tfv1
from .base import Callback from .base import Callback
__all__ = ['CallbackToHook', 'HookToCallback'] __all__ = ['CallbackToHook', 'HookToCallback']
class CallbackToHook(tf.train.SessionRunHook): class CallbackToHook(tfv1.train.SessionRunHook):
""" This is only for internal implementation of """ This is only for internal implementation of
before_run/after_run callbacks. before_run/after_run callbacks.
You shouldn't need to use this. You shouldn't need to use this.
......
...@@ -13,6 +13,7 @@ from tensorflow.python.training.monitored_session import _HookedSession as Hooke ...@@ -13,6 +13,7 @@ from tensorflow.python.training.monitored_session import _HookedSession as Hooke
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
from ..input_source import FeedInput, InputSource, QueueInput, StagingInput from ..input_source import FeedInput, InputSource, QueueInput, StagingInput
from ..tfutils.tower import PredictTowerContext from ..tfutils.tower import PredictTowerContext
from ..tfutils.common import tfv1
from ..utils import logger from ..utils import logger
from ..utils.utils import get_tqdm_kwargs from ..utils.utils import get_tqdm_kwargs
from .base import Callback from .base import Callback
...@@ -27,7 +28,7 @@ def _device_from_int(dev): ...@@ -27,7 +28,7 @@ def _device_from_int(dev):
return '/gpu:{}'.format(dev) if dev >= 0 else '/cpu:0' return '/gpu:{}'.format(dev) if dev >= 0 else '/cpu:0'
class InferencerToHook(tf.train.SessionRunHook): class InferencerToHook(tfv1.train.SessionRunHook):
def __init__(self, inf, fetches): def __init__(self, inf, fetches):
self._inf = inf self._inf = inf
self._fetches = fetches self._fetches = fetches
......
...@@ -20,7 +20,7 @@ class ModelSaver(Callback): ...@@ -20,7 +20,7 @@ class ModelSaver(Callback):
def __init__(self, max_to_keep=10, def __init__(self, max_to_keep=10,
keep_checkpoint_every_n_hours=0.5, keep_checkpoint_every_n_hours=0.5,
checkpoint_dir=None, checkpoint_dir=None,
var_collections=[tf.GraphKeys.GLOBAL_VARIABLES]): var_collections=None):
""" """
Args: Args:
max_to_keep (int): the same as in ``tf.train.Saver``. max_to_keep (int): the same as in ``tf.train.Saver``.
...@@ -29,6 +29,8 @@ class ModelSaver(Callback): ...@@ -29,6 +29,8 @@ class ModelSaver(Callback):
checkpoint_dir (str): Defaults to ``logger.get_logger_dir()``. checkpoint_dir (str): Defaults to ``logger.get_logger_dir()``.
var_collections (str or list of str): collection of the variables (or list of collections) to save. var_collections (str or list of str): collection of the variables (or list of collections) to save.
""" """
if var_collections is None:
var_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
self._max_to_keep = max_to_keep self._max_to_keep = max_to_keep
self._keep_every_n_hours = keep_checkpoint_every_n_hours self._keep_every_n_hours = keep_checkpoint_every_n_hours
......
...@@ -116,7 +116,7 @@ class MergeAllSummaries_RunWithOp(Callback): ...@@ -116,7 +116,7 @@ class MergeAllSummaries_RunWithOp(Callback):
self.trainer.monitors.put_summary(summary) self.trainer.monitors.put_summary(summary)
def MergeAllSummaries(period=0, run_alone=False, key=tf.GraphKeys.SUMMARIES): def MergeAllSummaries(period=0, run_alone=False, key=None):
""" """
This callback is enabled by default. This callback is enabled by default.
Evaluate all summaries by `tf.summary.merge_all`, and write them to logs. Evaluate all summaries by `tf.summary.merge_all`, and write them to logs.
...@@ -133,6 +133,8 @@ def MergeAllSummaries(period=0, run_alone=False, key=tf.GraphKeys.SUMMARIES): ...@@ -133,6 +133,8 @@ def MergeAllSummaries(period=0, run_alone=False, key=tf.GraphKeys.SUMMARIES):
key (str): the collection of summary tensors. Same as in `tf.summary.merge_all`. key (str): the collection of summary tensors. Same as in `tf.summary.merge_all`.
Default is ``tf.GraphKeys.SUMMARIES``. Default is ``tf.GraphKeys.SUMMARIES``.
""" """
if key is None:
key = tf.GraphKeys.SUMMARIES
period = int(period) period = int(period)
if run_alone: if run_alone:
return MergeAllSummaries_RunAlone(period, key) return MergeAllSummaries_RunAlone(period, key)
......
...@@ -39,7 +39,7 @@ def regularize_cost(regex, func, name='regularize_cost'): ...@@ -39,7 +39,7 @@ def regularize_cost(regex, func, name='regularize_cost'):
Args: Args:
regex (str): a regex to match variable names, e.g. "conv.*/W" regex (str): a regex to match variable names, e.g. "conv.*/W"
func: the regularization function, which takes a tensor and returns a scalar tensor. func: the regularization function, which takes a tensor and returns a scalar tensor.
E.g., ``tf.contrib.layers.l2_regularizer``. E.g., ``tf.nn.l2_loss, tf.contrib.layers.l1_regularizer(0.001)``.
Returns: Returns:
tf.Tensor: a scalar, the total regularization cost. tf.Tensor: a scalar, the total regularization cost.
......
...@@ -150,11 +150,25 @@ def gpu_available_in_session(): ...@@ -150,11 +150,25 @@ def gpu_available_in_session():
@deprecated("Use get_tf_version_tuple instead.", "2019-01-31") @deprecated("Use get_tf_version_tuple instead.", "2019-01-31")
def get_tf_version_number(): def get_tf_version_number():
return float('.'.join(tf.VERSION.split('.')[:2])) return float('.'.join(tf.__version__.split('.')[:2]))
def get_tf_version_tuple(): def get_tf_version_tuple():
""" """
Return TensorFlow version as a 2-element tuple (for comparison). Return TensorFlow version as a 2-element tuple (for comparison).
""" """
return tuple(map(int, tf.VERSION.split('.')[:2])) return tuple(map(int, tf.__version__.split('.')[:2]))
def is_tf2():
try:
from tensorflow.python import tf2
return tf2.enabled()
except Exception:
return False
if is_tf2():
tfv1 = tf.compat.v1
else:
tfv1 = tf
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
from contextlib import contextmanager from contextlib import contextmanager
import tensorflow as tf import tensorflow as tf
from ..tfutils.common import get_tf_version_tuple from ..tfutils.common import get_tf_version_tuple, tfv1
from ..utils.develop import HIDE_DOC from ..utils.develop import HIDE_DOC
from .gradproc import FilterNoneGrad, GradientProcessor from .gradproc import FilterNoneGrad, GradientProcessor
...@@ -14,7 +14,7 @@ __all__ = ['apply_grad_processors', 'ProxyOptimizer', ...@@ -14,7 +14,7 @@ __all__ = ['apply_grad_processors', 'ProxyOptimizer',
'AccumGradOptimizer'] 'AccumGradOptimizer']
class ProxyOptimizer(tf.train.Optimizer): class ProxyOptimizer(tfv1.train.Optimizer):
""" """
A transparent proxy which delegates all methods of :class:`tf.train.Optimizer` A transparent proxy which delegates all methods of :class:`tf.train.Optimizer`
""" """
......
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
import tensorflow as tf import tensorflow as tf
from ..tfutils.common import tfv1
from ..utils import logger from ..utils import logger
from .common import get_default_sess_config from .common import get_default_sess_config
__all__ = ['NewSessionCreator', 'ReuseSessionCreator', 'SessionCreatorAdapter'] __all__ = ['NewSessionCreator', 'ReuseSessionCreator']
""" """
A SessionCreator should: A SessionCreator should:
...@@ -18,7 +19,7 @@ A SessionCreator should: ...@@ -18,7 +19,7 @@ A SessionCreator should:
""" """
class NewSessionCreator(tf.train.SessionCreator): class NewSessionCreator(tfv1.train.SessionCreator):
def __init__(self, target='', config=None): def __init__(self, target='', config=None):
""" """
Args: Args:
...@@ -47,7 +48,7 @@ bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.") ...@@ -47,7 +48,7 @@ bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
return sess return sess
class ReuseSessionCreator(tf.train.SessionCreator): class ReuseSessionCreator(tfv1.train.SessionCreator):
def __init__(self, sess): def __init__(self, sess):
""" """
Args: Args:
...@@ -57,19 +58,3 @@ class ReuseSessionCreator(tf.train.SessionCreator): ...@@ -57,19 +58,3 @@ class ReuseSessionCreator(tf.train.SessionCreator):
def create_session(self): def create_session(self):
return self.sess return self.sess
class SessionCreatorAdapter(tf.train.SessionCreator):
def __init__(self, session_creator, func):
"""
Args:
session_creator (tf.train.SessionCreator): a session creator
func (tf.Session -> tf.Session): takes a session created by
``session_creator``, and return a new session to be returned by ``self.create_session``
"""
self._creator = session_creator
self._func = func
def create_session(self):
sess = self._creator.create_session()
return self._func(sess)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment