Commit 2ce43d70 authored by Yuxin Wu's avatar Yuxin Wu

make code importable under tf2

parent 7b4980c9
......@@ -80,11 +80,13 @@ class RunUpdateOps(RunOp):
each `sess.run` call.
"""
def __init__(self, collection=tf.GraphKeys.UPDATE_OPS):
def __init__(self, collection=None):
"""
Args:
collection (str): collection of ops to run. Defaults to ``tf.GraphKeys.UPDATE_OPS``
"""
if collection is None:
collection = tf.GraphKeys.UPDATE_OPS
name = 'UPDATE_OPS' if collection == tf.GraphKeys.UPDATE_OPS else collection
def f():
......
......@@ -6,12 +6,13 @@
import tensorflow as tf
from ..tfutils.common import tfv1
from .base import Callback
__all__ = ['CallbackToHook', 'HookToCallback']
class CallbackToHook(tf.train.SessionRunHook):
class CallbackToHook(tfv1.train.SessionRunHook):
""" This is only for internal implementation of
before_run/after_run callbacks.
You shouldn't need to use this.
......
......@@ -13,6 +13,7 @@ from tensorflow.python.training.monitored_session import _HookedSession as Hooke
from ..dataflow.base import DataFlow
from ..input_source import FeedInput, InputSource, QueueInput, StagingInput
from ..tfutils.tower import PredictTowerContext
from ..tfutils.common import tfv1
from ..utils import logger
from ..utils.utils import get_tqdm_kwargs
from .base import Callback
......@@ -27,7 +28,7 @@ def _device_from_int(dev):
return '/gpu:{}'.format(dev) if dev >= 0 else '/cpu:0'
class InferencerToHook(tf.train.SessionRunHook):
class InferencerToHook(tfv1.train.SessionRunHook):
def __init__(self, inf, fetches):
self._inf = inf
self._fetches = fetches
......
......@@ -20,7 +20,7 @@ class ModelSaver(Callback):
def __init__(self, max_to_keep=10,
keep_checkpoint_every_n_hours=0.5,
checkpoint_dir=None,
var_collections=[tf.GraphKeys.GLOBAL_VARIABLES]):
var_collections=None):
"""
Args:
max_to_keep (int): the same as in ``tf.train.Saver``.
......@@ -29,6 +29,8 @@ class ModelSaver(Callback):
checkpoint_dir (str): Defaults to ``logger.get_logger_dir()``.
var_collections (str or list of str): collection of the variables (or list of collections) to save.
"""
if var_collections is None:
var_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
self._max_to_keep = max_to_keep
self._keep_every_n_hours = keep_checkpoint_every_n_hours
......
......@@ -116,7 +116,7 @@ class MergeAllSummaries_RunWithOp(Callback):
self.trainer.monitors.put_summary(summary)
def MergeAllSummaries(period=0, run_alone=False, key=tf.GraphKeys.SUMMARIES):
def MergeAllSummaries(period=0, run_alone=False, key=None):
"""
This callback is enabled by default.
Evaluate all summaries by `tf.summary.merge_all`, and write them to logs.
......@@ -133,6 +133,8 @@ def MergeAllSummaries(period=0, run_alone=False, key=tf.GraphKeys.SUMMARIES):
key (str): the collection of summary tensors. Same as in `tf.summary.merge_all`.
Default is ``tf.GraphKeys.SUMMARIES``.
"""
if key is None:
key = tf.GraphKeys.SUMMARIES
period = int(period)
if run_alone:
return MergeAllSummaries_RunAlone(period, key)
......
......@@ -39,7 +39,7 @@ def regularize_cost(regex, func, name='regularize_cost'):
Args:
regex (str): a regex to match variable names, e.g. "conv.*/W"
func: the regularization function, which takes a tensor and returns a scalar tensor.
E.g., ``tf.contrib.layers.l2_regularizer``.
E.g., ``tf.nn.l2_loss, tf.contrib.layers.l1_regularizer(0.001)``.
Returns:
tf.Tensor: a scalar, the total regularization cost.
......
......@@ -150,11 +150,25 @@ def gpu_available_in_session():
@deprecated("Use get_tf_version_tuple instead.", "2019-01-31")
def get_tf_version_number():
return float('.'.join(tf.VERSION.split('.')[:2]))
return float('.'.join(tf.__version__.split('.')[:2]))
def get_tf_version_tuple():
"""
Return TensorFlow version as a 2-element tuple (for comparison).
"""
return tuple(map(int, tf.VERSION.split('.')[:2]))
return tuple(map(int, tf.__version__.split('.')[:2]))
def is_tf2():
try:
from tensorflow.python import tf2
return tf2.enabled()
except Exception:
return False
if is_tf2():
tfv1 = tf.compat.v1
else:
tfv1 = tf
......@@ -5,7 +5,7 @@
from contextlib import contextmanager
import tensorflow as tf
from ..tfutils.common import get_tf_version_tuple
from ..tfutils.common import get_tf_version_tuple, tfv1
from ..utils.develop import HIDE_DOC
from .gradproc import FilterNoneGrad, GradientProcessor
......@@ -14,7 +14,7 @@ __all__ = ['apply_grad_processors', 'ProxyOptimizer',
'AccumGradOptimizer']
class ProxyOptimizer(tf.train.Optimizer):
class ProxyOptimizer(tfv1.train.Optimizer):
"""
A transparent proxy which delegates all methods of :class:`tf.train.Optimizer`
"""
......
......@@ -4,10 +4,11 @@
import tensorflow as tf
from ..tfutils.common import tfv1
from ..utils import logger
from .common import get_default_sess_config
__all__ = ['NewSessionCreator', 'ReuseSessionCreator', 'SessionCreatorAdapter']
__all__ = ['NewSessionCreator', 'ReuseSessionCreator']
"""
A SessionCreator should:
......@@ -18,7 +19,7 @@ A SessionCreator should:
"""
class NewSessionCreator(tf.train.SessionCreator):
class NewSessionCreator(tfv1.train.SessionCreator):
def __init__(self, target='', config=None):
"""
Args:
......@@ -47,7 +48,7 @@ bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
return sess
class ReuseSessionCreator(tf.train.SessionCreator):
class ReuseSessionCreator(tfv1.train.SessionCreator):
def __init__(self, sess):
"""
Args:
......@@ -57,19 +58,3 @@ class ReuseSessionCreator(tf.train.SessionCreator):
def create_session(self):
return self.sess
class SessionCreatorAdapter(tf.train.SessionCreator):
def __init__(self, session_creator, func):
"""
Args:
session_creator (tf.train.SessionCreator): a session creator
func (tf.Session -> tf.Session): takes a session created by
``session_creator``, and return a new session to be returned by ``self.create_session``
"""
self._creator = session_creator
self._func = func
def create_session(self):
sess = self._creator.create_session()
return self._func(sess)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment