Commit 505e28eb authored by Yuxin Wu

Backport TensorSpec; tf=tf.compat.v1 in many files.

parent e4941595
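The recurring change in this commit: files stop importing TensorFlow directly and instead import the `tfv1` alias from the new `tensorpack/compat` package, so v1-style calls keep working whether TF1 or TF2 is installed. A minimal sketch of the pattern (module path as used in this repository):

```python
# `tfv1` resolves to tf.compat.v1 under TF2 and to plain `tf` under TF1,
# so v1 APIs such as placeholders and sessions stay available either way.
from tensorpack.compat import tfv1 as tf

x = tf.placeholder(tf.float32, (None, 28, 28), name='input')
sess = tf.Session()
```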
......@@ -16,9 +16,15 @@ If you think:
Then it is a good time to open an issue.
## How to print/dump intermediate results in training
1. Learn `tf.Print`.
## How to print/dump intermediate results during training
1. Learn `tf.Print`. Most of the time, adding one line in between:
```python
tensor = obtain_a_tensor()
tensor = tf.Print(tensor, [tf.shape(tensor), tensor], tensor.name, summarize=100)
use_the_tensor(tensor)
```
is sufficient.
2. Know [DumpTensors](../modules/callbacks.html#tensorpack.callbacks.DumpTensors),
[ProcessTensors](../modules/callbacks.html#tensorpack.callbacks.ProcessTensors) callbacks.
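A minimal usage sketch of these callbacks (the tensor names here are hypothetical):
```python
from tensorpack.callbacks import DumpTensors, ProcessTensors

callbacks = [
    # dump the values of the named tensors every step, under the logger directory
    DumpTensors(['conv1/output:0']),
    # run an arbitrary function on the fetched values every step
    ProcessTensors(['cross_entropy_loss'], lambda c: print(c)),
]
```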
......
......@@ -21,8 +21,8 @@ class Model(ModelDesc):
"""
Define all the inputs (with type, shape, name) that the graph will need.
"""
return [tf.placeholder(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
tf.placeholder(tf.int32, (None,), 'label')]
return [tf.TensorSpec((None, IMAGE_SIZE, IMAGE_SIZE), tf.float32, 'input'),
tf.TensorSpec((None,), tf.int32, 'label')]
def build_graph(self, image, label):
"""This function should build the model which takes the input variables
......@@ -51,7 +51,7 @@ class Model(ModelDesc):
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') # the average cross-entropy loss
correct = tf.cast(tf.nn.in_top_k(logits, label, 1), tf.float32, name='correct')
correct = tf.cast(tf.nn.in_top_k(predictions=logits, targets=label, k=1), tf.float32, name='correct')  # keyword arguments, because TF2's in_top_k reorders them to (targets, predictions, k)
accuracy = tf.reduce_mean(correct, name='accuracy')
# This will monitor training error & accuracy (in a moving average fashion). The value will be automatically
......
......@@ -4,7 +4,7 @@
from abc import ABCMeta
import six
import tensorflow as tf
from ..compat import tfv1 as tf
from ..tfutils.common import get_op_or_tensor_by_name
......
......@@ -6,9 +6,9 @@
import numpy as np
import os
import tensorflow as tf
from six.moves import zip
from ..compat import tfv1 as tf
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
from .base import Callback
......
......@@ -6,7 +6,7 @@ import traceback
from contextlib import contextmanager
from time import time as timer
import six
import tensorflow as tf
from ..compat import tfv1 as tf
from ..utils import logger
from ..utils.utils import humanize_time_delta
......
......@@ -6,7 +6,7 @@
import tensorflow as tf
from ..tfutils.common import tfv1
from ..compat import tfv1
from ..utils.develop import HIDE_DOC
from .base import Callback
......
......@@ -5,15 +5,14 @@
import itertools
import sys
from contextlib import contextmanager
import tensorflow as tf
import tqdm
from six.moves import range
from tensorflow.python.training.monitored_session import _HookedSession as HookedSession
from ..compat import tfv1 as tf
from ..dataflow.base import DataFlow
from ..input_source import FeedInput, InputSource, QueueInput, StagingInput
from ..tfutils.tower import PredictTowerContext
from ..tfutils.common import tfv1
from ..utils import logger
from ..utils.utils import get_tqdm_kwargs
from .base import Callback
......@@ -28,7 +27,7 @@ def _device_from_int(dev):
return '/gpu:{}'.format(dev) if dev >= 0 else '/cpu:0'
class InferencerToHook(tfv1.train.SessionRunHook):
class InferencerToHook(tf.train.SessionRunHook):
def __init__(self, inf, fetches):
self._inf = inf
self._fetches = fetches
......
......@@ -12,8 +12,8 @@ import time
from collections import defaultdict
from datetime import datetime
import six
import tensorflow as tf
from ..compat import tfv1 as tf
from ..libinfo import __git_version__
from ..tfutils.summary import create_image_summary, create_scalar_summary
from ..utils import logger
......
......@@ -4,8 +4,8 @@
import os
from datetime import datetime
import tensorflow as tf
from ..compat import tfv1 as tf
from ..utils import logger
from .base import Callback
......@@ -40,8 +40,8 @@ class ModelSaver(Callback):
if checkpoint_dir is None:
checkpoint_dir = logger.get_logger_dir()
if checkpoint_dir is not None:
if not tf.gfile.IsDirectory(checkpoint_dir):
tf.gfile.MakeDirs(checkpoint_dir)
if not tf.gfile.IsDirectory(checkpoint_dir): # v2: tf.io.gfile.isdir
tf.gfile.MakeDirs(checkpoint_dir) # v2: tf.io.gfile.makedirs
self.checkpoint_dir = checkpoint_dir
def _setup_graph(self):
......
......@@ -3,10 +3,10 @@
""" Some common step callbacks. """
import tensorflow as tf
import tqdm
from six.moves import zip
from ..compat import tfv1 as tf
from ..tfutils.common import get_global_step_var, get_op_tensor_name
from ..utils import logger
from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
......
......@@ -4,8 +4,8 @@
import numpy as np
from collections import deque
import tensorflow as tf
from ..compat import tfv1 as tf
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
......
#!/usr/bin/env python

import tensorflow as tf


def backport_tensor_spec():
    if hasattr(tf, 'TensorSpec'):
        return tf.TensorSpec
    try:
        # available since 1.7
        from tensorflow.python.framework.tensor_spec import TensorSpec
    except ImportError:
        pass
    else:
        tf.TensorSpec = TensorSpec
        return TensorSpec
    from .tensor_spec import TensorSpec
    tf.TensorSpec = TensorSpec
    return TensorSpec


def is_tfv2():
    try:
        from tensorflow.python import tf2
        return tf2.enabled()
    except Exception:
        return False


if is_tfv2():
    tfv1 = tf.compat.v1
    if not hasattr(tf, 'layers'):
        # promised at https://github.com/tensorflow/community/pull/24#issuecomment-440453886
        tf.layers = tf.keras.layers
else:
    tfv1 = tf
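A short usage sketch of the helpers above (the shape and name are hypothetical):

```python
from tensorpack.compat import backport_tensor_spec, tfv1

# returns the native tf.TensorSpec when the installed TF has one (>=1.7),
# otherwise the local copy from compat/tensor_spec.py below
TensorSpec = backport_tensor_spec()
spec = TensorSpec((None, 28, 28), tfv1.float32, 'input')
```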
"""
Copied from tensorflow/python/framework/tensor_spec.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
class TensorSpec(object):
"""Describes a tf.Tensor.
Metadata for describing the `tf.Tensor` objects accepted or returned
by some TensorFlow APIs.
"""
__slots__ = ["_shape", "_shape_tuple", "_dtype", "_name"]
def __init__(self, shape, dtype=dtypes.float32, name=None):
"""Creates a TensorSpec.
Args:
shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
dtype: Value convertible to `tf.DType`. The type of the tensor values.
name: Optional name for the Tensor.
Raises:
TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is
not convertible to a `tf.DType`.
"""
self._shape = tensor_shape.TensorShape(shape)
try:
self._shape_tuple = tuple(self.shape.as_list())
except ValueError:
self._shape_tuple = None
self._dtype = dtypes.as_dtype(dtype)
self._name = name
@classmethod
def from_spec(cls, spec, name=None):
return cls(spec.shape, spec.dtype, name or spec.name)
@classmethod
def from_tensor(cls, tensor, name=None):
if isinstance(tensor, ops.EagerTensor):
return TensorSpec(tensor.shape, tensor.dtype, name)
elif isinstance(tensor, ops.Tensor):
return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name)
else:
raise ValueError("`tensor` should be a tf.Tensor")
@property
def shape(self):
"""Returns the `TensorShape` that represents the shape of the tensor."""
return self._shape
@property
def dtype(self):
"""Returns the `dtype` of elements in the tensor."""
return self._dtype
@property
def name(self):
"""Returns the (optionally provided) name of the described tensor."""
return self._name
def is_compatible_with(self, spec_or_tensor):
"""Returns True if spec_or_tensor is compatible with this TensorSpec.
Two tensors are considered compatible if they have the same dtype
and their shapes are compatible (see `tf.TensorShape.is_compatible_with`).
Args:
spec_or_tensor: A tf.TensorSpec or a tf.Tensor
Returns:
True if spec_or_tensor is compatible with self.
"""
return (self._dtype.is_compatible_with(spec_or_tensor.dtype) and
self._shape.is_compatible_with(spec_or_tensor.shape))
def __repr__(self):
return "TensorSpec(shape={}, dtype={}, name={})".format(
self.shape, repr(self.dtype), repr(self.name))
def __hash__(self):
return hash((self._shape_tuple, self.dtype))
def __eq__(self, other):
return (self._shape_tuple == other._shape_tuple # pylint: disable=protected-access
and self.dtype == other.dtype
and self._name == other._name) # pylint: disable=protected-access
def __ne__(self, other):
return not self == other
def __reduce__(self):
return TensorSpec, (self._shape, self._dtype, self._name)
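A minimal sketch of how the class behaves, assuming a TF1 graph (shapes and names hypothetical):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, (None, 224, 224, 3), name='image')
spec = TensorSpec((None, 224, 224, 3), tf.float32, 'image')

assert spec.is_compatible_with(x)                 # dtype matches and shapes are compatible
assert TensorSpec.from_tensor(x).name == 'image'  # name is taken from the producing op
```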
......@@ -7,13 +7,12 @@ import tensorflow as tf
from ..models.regularize import regularize_cost_from_collection
from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_tf_version_tuple
from ..utils import logger
from ..utils.argtools import memoized_method
from ..utils.develop import log_deprecated
from ..compat import backport_tensor_spec, tfv1
if get_tf_version_tuple() >= (1, 7):
from tensorflow.python.framework.tensor_spec import TensorSpec
TensorSpec = backport_tensor_spec()
__all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']
......@@ -49,8 +48,8 @@ class InputDesc(
Returns:
tf.Tensor:
"""
with tf.name_scope(None): # clear any name scope it might get called in
ret = tf.placeholder(
with tfv1.name_scope(None): # clear any name scope it might get called in
ret = tfv1.placeholder(
self.type, shape=self.shape, name=self.name)
self._register_cached_placeholder(ret)
return ret
......@@ -63,7 +62,7 @@ class InputDesc(
Returns:
tf.Tensor:
"""
g = tf.get_default_graph()
g = tfv1.get_default_graph()
if g in self._cached_placeholder:
return self._cached_placeholder[g]
else:
......
......@@ -8,6 +8,7 @@ from itertools import chain
import tensorflow as tf
from six.moves import range, zip
from ..compat import tfv1
from ..callbacks.base import Callback, CallbackFactory
from ..callbacks.graph import RunOp
from ..dataflow import DataFlow, MapData, RepeatedData
......@@ -84,7 +85,7 @@ class FeedInput(InputSource):
dp = next(self._itr)
assert len(dp) == len(self._placeholders), "[FeedInput] datapoints and inputs are of different length!"
feed = _make_feeds(self._placeholders, dp)
return tf.train.SessionRunArgs(fetches=[], feed_dict=feed)
return tfv1.train.SessionRunArgs(fetches=[], feed_dict=feed)
def _reset(self):
self._itr = self._ds.__iter__()
......@@ -228,9 +229,9 @@ class QueueInput(FeedfreeInput):
"""
self.thread.pause() # pause enqueue
opt = tf.RunOptions()
opt = tfv1.RunOptions()
opt.timeout_in_ms = 2000 # 2s
sess = tf.get_default_session()
sess = tfv1.get_default_session()
# dequeue until empty
try:
while True:
......@@ -304,7 +305,7 @@ class BatchQueueInput(QueueInput):
# prepare placeholders without the first dimension
placehdrs_nobatch = []
for p in self.input_placehdrs:
placehdrs_nobatch.append(tf.placeholder(
placehdrs_nobatch.append(tfv1.placeholder(
dtype=p.dtype, shape=p.get_shape().as_list()[1:],
name=get_op_tensor_name(p.name)[0] + '-nobatch'))
......@@ -546,7 +547,7 @@ class StagingInput(FeedfreeInput):
unstage_ops = self._input._get_unstage_ops()
unstage_op = tf.group(*unstage_ops, name='unstage_all')
self._check_dependency_op = unstage_ops[0]
self.fetches = tf.train.SessionRunArgs(
self.fetches = tfv1.train.SessionRunArgs(
fetches=[self.stage_op, unstage_op])
def _prefill(self, sess):
......
......@@ -52,7 +52,7 @@ os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '0'
try:
import tensorflow as tf # noqa
_version = tf.__version__.split('.')
assert int(_version[0]) >= 1 and int(_version[1]) >= 3, "TF>=1.3 is required!"
assert (int(_version[0]), int(_version[1])) >= (1, 3), "TF>=1.3 is required!"
_HAS_TF = True
except ImportError:
print("Failed to import tensorflow.")
......
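The assert rewrite above matters because the old check compared major and minor versions independently: under TF 2.0 the minor version is 0, so `int(_version[1]) >= 3` is False and the import would wrongly fail. Tuples compare lexicographically, which gives the intended semantics. A quick illustration:

```python
# the failure mode the new assert fixes
_version = '2.0.0'.split('.')
old_ok = int(_version[0]) >= 1 and int(_version[1]) >= 3    # False -- wrongly rejects TF 2.0
new_ok = (int(_version[0]), int(_version[1])) >= (1, 3)     # True  -- lexicographic comparison
assert not old_ok and new_ok
```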
......@@ -4,7 +4,7 @@
import re
import six
import tensorflow as tf
from ..compat import tfv1 as tf # this should be avoided first in model code
from tensorflow.python.training import moving_averages
from ..tfutils.collection import backup_collection, restore_collection
......
......@@ -2,7 +2,7 @@
# File: conv2d.py
import tensorflow as tf
from ..compat import tfv1 as tf # this should be avoided first in model code
from ..tfutils.common import get_tf_version_tuple
from ..utils.argtools import get_data_format, shape2d, shape4d, log_once
......
......@@ -3,7 +3,7 @@
import numpy as np
import tensorflow as tf
from ..compat import tfv1 as tf # this should be avoided first in model code
from ..tfutils.common import get_tf_version_tuple
from .common import VariableHolder, layer_register
......
......@@ -2,7 +2,7 @@
# File: layer_norm.py
import tensorflow as tf
from ..compat import tfv1 as tf # this should be avoided first in model code
from ..utils.argtools import get_data_format
from .common import VariableHolder, layer_register
......
......@@ -2,7 +2,7 @@
# File: pool.py
import numpy as np
import tensorflow as tf
from ..compat import tfv1 as tf # this should be avoided first in model code
from ..utils.argtools import get_data_format, shape2d
from ..utils.develop import log_deprecated
......
......@@ -8,6 +8,7 @@ from functools import wraps
import six
import tensorflow as tf
from ..compat import tfv1
from ..tfutils.argscope import get_arg_scope
from ..tfutils.model_utils import get_shape_str
from ..utils import logger
......@@ -117,7 +118,7 @@ def layer_register(
# del actual_args[k]
if name is not None: # use scope
with tf.variable_scope(name) as scope:
with tfv1.variable_scope(name) as scope:
# this name is only used to suppress logging, doesn't hurt to do some heuristics
scope_name = re.sub('tower[0-9]+/', '', scope.name)
do_log_shape = log_shape and scope_name not in _LAYER_LOGGED
......
......@@ -5,6 +5,7 @@
import re
import tensorflow as tf
from ..compat import tfv1
from ..tfutils.common import get_tf_version_tuple
from ..tfutils.tower import get_current_tower_context
from ..utils import logger
......@@ -60,13 +61,13 @@ def regularize_cost(regex, func, name='regularize_cost'):
# If vars are shared, regularize all of them
# If vars are replicated, only regularize those in the current tower
if ctx.has_own_variables:
params = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
params = ctx.get_collection_in_tower(tfv1.GraphKeys.TRAINABLE_VARIABLES)
else:
params = tf.trainable_variables()
params = tfv1.trainable_variables()
names = []
with tf.name_scope(name + '_internals'):
with tfv1.name_scope(name + '_internals'):
costs = []
for p in params:
para_name = p.op.name
......@@ -119,9 +120,9 @@ def regularize_cost_from_collection(name='regularize_cost'):
# NOTE: this collection doesn't always grow with towers.
# It only grows with actual variable creation, but not get_variable call.
if ctx.has_own_variables: # be careful of the first tower (name='')
losses = ctx.get_collection_in_tower(tf.GraphKeys.REGULARIZATION_LOSSES)
losses = ctx.get_collection_in_tower(tfv1.GraphKeys.REGULARIZATION_LOSSES)
else:
losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
losses = tfv1.get_collection(tfv1.GraphKeys.REGULARIZATION_LOSSES)
if len(losses) > 0:
logger.info("regularize_cost_from_collection() found {} regularizers "
"in REGULARIZATION_LOSSES collection.".format(len(losses)))
......
......@@ -5,7 +5,8 @@
from contextlib import contextmanager
from copy import copy
import six
import tensorflow as tf
from ..compat import tfv1 as tf
from ..utils import logger
from ..utils.argtools import memoized
......
......@@ -5,12 +5,13 @@
import tensorflow as tf
from six.moves import map
from ..compat import tfv1
from ..utils.argtools import graph_memoized
__all__ = ['get_default_sess_config',
'get_global_step_value',
'get_global_step_var',
'get_tf_version_tuple'
'get_tf_version_tuple',
# 'get_op_tensor_name',
# 'get_tensors_by_names',
# 'get_op_or_tensor_by_name',
......@@ -30,7 +31,7 @@ def get_default_sess_config(mem_fraction=0.99):
Returns:
tf.ConfigProto: the config to use.
"""
conf = tf.ConfigProto()
conf = tfv1.ConfigProto()
conf.allow_soft_placement = True
# conf.log_device_placement = True
......@@ -64,9 +65,9 @@ def get_global_step_var():
Returns:
tf.Tensor: the global_step variable in the current graph. Create if doesn't exist.
"""
scope = tf.VariableScope(reuse=False, name='') # the root vs
with tf.variable_scope(scope):
var = tf.train.get_or_create_global_step()
scope = tfv1.VariableScope(reuse=False, name='') # the root vs
with tfv1.variable_scope(scope):
var = tfv1.train.get_or_create_global_step()
return var
......@@ -78,8 +79,8 @@ def get_global_step_value():
Has to be called under a default session.
"""
return tf.train.global_step(
tf.get_default_session(),
return tfv1.train.global_step(
tfv1.get_default_session(),
get_global_step_var())
......@@ -108,7 +109,7 @@ def get_tensors_by_names(names):
names (list):
"""
ret = []
G = tf.get_default_graph()
G = tfv1.get_default_graph()
for n in names:
opn, varn = get_op_tensor_name(n)
ret.append(G.get_tensor_by_name(varn))
......@@ -125,7 +126,7 @@ def get_op_or_tensor_by_name(name):
Raises:
KeyError, if the name doesn't exist
"""
G = tf.get_default_graph()
G = tfv1.get_default_graph()
def f(n):
if len(n) >= 3 and n[-2] == ':':
......@@ -140,7 +141,7 @@ def get_op_or_tensor_by_name(name):
def gpu_available_in_session():
sess = tf.get_default_session()
sess = tfv1.get_default_session()
for dev in sess.list_devices():
if dev.device_type.lower() == 'gpu':
return True
......@@ -152,17 +153,3 @@ def get_tf_version_tuple():
Return TensorFlow version as a 2-element tuple (for comparison).
"""
return tuple(map(int, tf.__version__.split('.')[:2]))
def is_tf2():
try:
from tensorflow.python import tf2
return tf2.enabled()
except Exception:
return False
if is_tf2():
tfv1 = tf.compat.v1
else:
tfv1 = tf
import tensorflow as tf
from tensorflow.contrib.graph_editor import get_backward_walk_ops
from ..utils.argtools import graph_memoized
......@@ -33,6 +32,7 @@ def dependency_of_targets(targets, op):
op = op.op
assert isinstance(op, tf.Operation), op
from tensorflow.contrib.graph_editor import get_backward_walk_ops
# alternative implementation can use graph_util.extract_sub_graph
dependent_ops = get_backward_walk_ops(targets, control_inputs=True)
return op in dependent_ops
......
......@@ -2,7 +2,7 @@
# File: model_utils.py
# Author: tensorpack contributors
import tensorflow as tf
from ..compat import tfv1 as tf
from tabulate import tabulate
from termcolor import colored
......
......@@ -5,7 +5,8 @@
from contextlib import contextmanager
import tensorflow as tf
from ..tfutils.common import get_tf_version_tuple, tfv1
from ..tfutils.common import get_tf_version_tuple
from ..compat import tfv1
from ..utils.develop import HIDE_DOC
from .gradproc import FilterNoneGrad, GradientProcessor
......
......@@ -4,8 +4,8 @@
import functools
from contextlib import contextmanager
import tensorflow as tf
from ..compat import tfv1 as tf
from ..utils.argtools import graph_memoized
from .common import get_tf_version_tuple
......
......@@ -2,10 +2,7 @@
# File: sesscreate.py
import tensorflow as tf
from tensorflow.contrib.graph_editor import get_backward_walk_ops
from ..tfutils.common import tfv1
from ..compat import tfv1 as tf, is_tfv2
from ..utils import logger
from .common import get_default_sess_config
......@@ -20,7 +17,7 @@ A SessionCreator should:
"""
class NewSessionCreator(tfv1.train.SessionCreator):
class NewSessionCreator(tf.train.SessionCreator):
def __init__(self, target='', config=None):
"""
Args:
......@@ -59,12 +56,16 @@ bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
return False
def run(op):
deps = get_backward_walk_ops(op, control_inputs=True)
for dep_op in deps:
if blocking_op(dep_op):
logger.warn(
"Initializer '{}' depends on a blocking op '{}'. This initializer is likely to hang!".format(
op.name, dep_op.name))
if not is_tfv2():  # tf.contrib (and its graph_editor) does not exist under TF2, so skip the check there
from tensorflow.contrib.graph_editor import get_backward_walk_ops
deps = get_backward_walk_ops(op, control_inputs=True)
for dep_op in deps:
if blocking_op(dep_op):
logger.warn(
"Initializer '{}' depends on a blocking op '{}'. "
"This initializer is likely to hang!".format(
op.name, dep_op.name))
sess.run(op)
run(tf.global_variables_initializer())
......@@ -73,7 +74,7 @@ bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
return sess
class ReuseSessionCreator(tfv1.train.SessionCreator):
class ReuseSessionCreator(tf.train.SessionCreator):
"""
Returns an existing session.
"""
......@@ -88,7 +89,7 @@ class ReuseSessionCreator(tfv1.train.SessionCreator):
return self.sess
class SessionCreatorAdapter(tfv1.train.SessionCreator):
class SessionCreatorAdapter(tf.train.SessionCreator):
"""
Apply a function on the output of a SessionCreator. Can be used to create a debug session.
"""
......
......@@ -5,10 +5,10 @@
import re
from contextlib import contextmanager
import six
import tensorflow as tf
from six.moves import range
from tensorflow.python.training import moving_averages
from ..compat import tfv1 as tf
from ..utils import logger
from ..utils.argtools import graph_memoized
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
......
......@@ -4,6 +4,7 @@
import tensorflow as tf
from ..compat import tfv1
from ..utils.develop import deprecated
__all__ = ['print_stat', 'rms']
......@@ -30,7 +31,7 @@ def rms(x, name=None):
"""
if name is None:
name = x.op.name + '/rms'
with tf.name_scope(None): # name already contains the scope
with tfv1.name_scope(None): # name already contains the scope
return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
......
......@@ -4,9 +4,10 @@
from abc import ABCMeta, abstractmethod, abstractproperty
import six
import tensorflow as tf
from six.moves import zip
from ..compat import tfv1 as tf
from ..utils import logger
from ..utils.argtools import call_only_once
from ..utils.develop import HIDE_DOC
......
......@@ -3,8 +3,8 @@
# Credit: Qinyao He
from contextlib import contextmanager
import tensorflow as tf
from ..compat import tfv1 as tf
from .common import get_tf_version_tuple
__all__ = ['custom_getter_scope', 'freeze_variables', 'remap_variables']
......
......@@ -8,6 +8,7 @@ import six
import tensorflow as tf
from six.moves import range
from ..compat import tfv1
from ..callbacks import Callback, Callbacks, Monitors, MonitorBase
from ..callbacks.steps import MaintainStepCounter
from ..tfutils import get_global_step_value
......@@ -222,7 +223,7 @@ class Trainer(object):
session_creator (tf.train.SessionCreator):
session_init (sessinit.SessionInit):
"""
assert isinstance(session_creator, tf.train.SessionCreator), session_creator
assert isinstance(session_creator, tfv1.train.SessionCreator), session_creator
assert isinstance(session_init, SessionInit), session_init
session_init._setup_graph()
......@@ -250,7 +251,7 @@ class Trainer(object):
which can be useful when the training is not done by a single `train_op`.
"""
hooks = self._callbacks.get_hooks()
self.hooked_sess = tf.train.MonitoredSession(
self.hooked_sess = tfv1.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
@call_only_once
......
# -*- coding: utf-8 -*-
# File: interface.py
import tensorflow as tf
from ..compat import tfv1
from ..input_source import DummyConstantInput, FeedfreeInput, FeedInput, InputSource, QueueInput, StagingInput
from ..utils import logger
from ..compat import is_tfv2
from .config import TrainConfig
from .tower import SingleCostTrainer
from .trainers import SimpleTrainer
......@@ -71,6 +71,9 @@ def launch_train_with_config(config, trainer):
launch_train_with_config(
config, SyncMultiGPUTrainerParameterServer(8, ps_device='gpu'))
"""
if is_tfv2():
tfv1.disable_eager_execution()
assert isinstance(trainer, SingleCostTrainer), trainer
assert isinstance(config, TrainConfig), config
assert config.model is not None
......@@ -99,7 +102,7 @@ def launch_train_with_config(config, trainer):
def _check_unused_regularization():
coll = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
coll = tfv1.get_collection(tfv1.GraphKeys.REGULARIZATION_LOSSES)
unconsumed_reg = []
for c in coll:
if len(c.consumers()) == 0:
......
......@@ -5,6 +5,7 @@ from abc import ABCMeta, abstractmethod
import six
import tensorflow as tf
from ..compat import tfv1, is_tfv2
from ..input_source import PlaceholderInput
from ..predict.base import OnlinePredictor
from ..tfutils.gradproc import FilterNoneGrad
......@@ -126,7 +127,7 @@ class TowerTrainer(Trainer):
input.setup(self.inputs_desc)
vs_name = self._vs_name_for_predictor(device_id)
with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
with tfv1.variable_scope(tfv1.get_variable_scope(), reuse=True), \
tf.device(device), PredictTowerContext(
tower_name, vs_name=vs_name):
logger.info("Building graph for predict tower '{}' on device {} {}...".format(
......@@ -254,15 +255,19 @@ class SingleCostTrainer(TowerTrainer):
return None # this is the tower function, could be called for inference
if ctx.has_own_variables:
varlist = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
varlist = ctx.get_collection_in_tower(tfv1.GraphKeys.TRAINABLE_VARIABLES)
else:
varlist = tf.trainable_variables()
varlist = tfv1.trainable_variables()
opt = get_opt_fn()
grads = opt.compute_gradients(
cost, var_list=varlist,
gate_gradients=self.GATE_GRADIENTS,
colocate_gradients_with_ops=self.COLOCATE_GRADIENTS_WITH_OPS,
aggregation_method=self.AGGREGATION_METHOD)
if is_tfv2() and isinstance(opt, tf.optimizers.Optimizer):
grads = opt.get_gradients(cost, varlist)
grads = list(zip(grads, varlist))
else:
grads = opt.compute_gradients(
cost, var_list=varlist,
gate_gradients=self.GATE_GRADIENTS,
colocate_gradients_with_ops=self.COLOCATE_GRADIENTS_WITH_OPS,
aggregation_method=self.AGGREGATION_METHOD)
grads = FilterNoneGrad().process(grads)
return grads
......
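Background for the branch above: Keras-style optimizers (`tf.optimizers.Optimizer`) have no `compute_gradients` method; their `get_gradients(cost, varlist)` returns only the gradient tensors, so zipping them with `varlist` recovers the same `(grad, var)` pairs that `tf.train.Optimizer.compute_gradients` produces, and the rest of the pipeline (e.g. `FilterNoneGrad`) can stay unchanged.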
......@@ -52,7 +52,7 @@ def graph_memoized(func):
"""
# TODO it keeps the graph alive
import tensorflow as tf
from ..compat import tfv1
GRAPH_ARG_NAME = '__IMPOSSIBLE_NAME_FOR_YOU__'
@memoized
......@@ -63,7 +63,7 @@ def graph_memoized(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
assert GRAPH_ARG_NAME not in kwargs, "No Way!!"
graph = tf.get_default_graph()
graph = tfv1.get_default_graph()
kwargs[GRAPH_ARG_NAME] = graph
return func_with_graph_arg(*args, **kwargs)
return wrapper
......
......@@ -5,7 +5,7 @@ ignore = E265,E741,E742,E743,W504,W605
exclude = .git,
__init__.py,
setup.py,
tensorpack/train/eager.py,
tensorpack/compat/*,
docs,
examples,
docs/conf.py
......