Commit 047579df authored by Yuxin Wu

TowerFuncWrapper -> TowerFunc

parent 50ff9036
@@ -384,6 +384,7 @@ _DEPRECATED_NAMES = set([
     'get_nr_gpu',
     'TrainingMonitor',
     'PeakMemoryTracker',
+    'TowerFuncWrapper',
     'PrefetchData',
     'MultiProcessPrefetchData',
@@ -391,7 +392,7 @@ _DEPRECATED_NAMES = set([
     'MultiThreadPrefetchData',
     # deprecated or renamed symbolic code
-    'Deconv2D', 'psnr',
+    'Deconv2D',
     # shouldn't appear in doc:
     'l2_regularizer', 'l1_regularizer',
...
@@ -8,7 +8,7 @@ import tensorflow as tf
 from tensorpack import BatchNorm, DataFlow, ModelDescBase, StagingInput, TowerTrainer, argscope
 from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter
 from tensorpack.tfutils.summary import add_moving_summary
-from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper
+from tensorpack.tfutils.tower import TowerContext, TowerFunc
 from tensorpack.utils import logger
 from tensorpack.utils.argtools import memoized_method
@@ -105,7 +105,7 @@ class GANTrainer(TowerTrainer):
         not needed. Just calling model.build_graph directly is OK.
         """
         # Build the graph
-        self.tower_func = TowerFuncWrapper(model.build_graph, model.inputs())
+        self.tower_func = TowerFunc(model.build_graph, model.inputs())
         with TowerContext('', is_training=True):
             self.tower_func(*input.get_input_tensors())
         opt = model.get_optimizer()
@@ -127,7 +127,7 @@ class GANTrainer(TowerTrainer):
             model.build_graph(*inputs)
             return [model.d_loss, model.g_loss]
-        self.tower_func = TowerFuncWrapper(get_cost, model.get_input_signature())
+        self.tower_func = TowerFunc(get_cost, model.get_input_signature())
         devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
         cost_list = DataParallelBuilder.build_on_towers(
             list(range(num_gpu)),
@@ -167,7 +167,7 @@ class SeparateGANTrainer(TowerTrainer):
         self.register_callback(cbs)
         # Build the graph
-        self.tower_func = TowerFuncWrapper(model.build_graph, model.inputs())
+        self.tower_func = TowerFunc(model.build_graph, model.inputs())
         with TowerContext('', is_training=True), \
                 argscope(BatchNorm, ema_update='internal'):
             # should not hook the EMA updates to both train_op, it will hurt training speed.
...
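For readers following the rename: all three hunks in this file touch the same pattern, in which a custom trainer wraps the model's graph-building method in a `TowerFunc` and invokes it under a `TowerContext`. A minimal sketch of that pattern under the new name; `model` and `input` are hypothetical stand-ins for a `ModelDescBase` subclass instance and an `InputSource` such as `QueueInput`:

```python
from tensorpack.tfutils.tower import TowerContext, TowerFunc

# `model` and `input` are hypothetical stand-ins for a ModelDescBase
# subclass instance and an InputSource (e.g. QueueInput).
tower_func = TowerFunc(model.build_graph, model.get_input_signature())
with TowerContext('', is_training=True):
    # Each call is recorded by the wrapper, together with its
    # name scope, variable scope and input/output tensors.
    tower_func(*input.get_input_tensors())
```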
@@ -120,7 +120,7 @@ if __name__ == '__main__':
     if get_num_gpu() <= 1:
         # single GPU:
-        launch_train_with_config(cfg, QueueInputTrainer())
+        launch_train_with_config(cfg, SimpleTrainer())
     else:
         # multi GPU:
         launch_train_with_config(cfg, SyncMultiGPUTrainerParameterServer(2))
...
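This example change follows from the deprecation at the end of this commit: `QueueInputTrainer` only asserted that the input is a `QueueInput` (see the last file below), so `SimpleTrainer` is a drop-in replacement:

```python
# Before (now deprecated): only usable with a QueueInput.
launch_train_with_config(cfg, QueueInputTrainer())

# After: SimpleTrainer accepts any InputSource, including QueueInput.
launch_train_with_config(cfg, SimpleTrainer())
```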
@@ -114,7 +114,7 @@ class InferenceRunner(InferenceRunnerBase):
             infs (list): a list of :class:`Inferencer` instances.
             tower_name (str): the name scope of the tower to build. Need to set a
                 different one if multiple InferenceRunner are used.
-            tower_func (tfutils.TowerFuncWrapper or None): the tower function to be used to build the graph.
+            tower_func (tfutils.TowerFunc or None): the tower function to be used to build the graph.
                 By defaults to call `trainer.tower_func` under a `training=False` TowerContext,
                 but you can change it to a different tower function
                 if you need to inference with several different graphs.
@@ -196,7 +196,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
             gpus (int or list[int]): #gpus, or list of GPU id
             tower_name (str): the name scope of the tower to build. Need to set a
                 different one if multiple InferenceRunner are used.
-            tower_func (tfutils.TowerFuncWrapper or None): the tower function to be used to build the graph.
+            tower_func (tfutils.TowerFunc or None): the tower function to be used to build the graph.
                 The tower function will be called under a `training=False` TowerContext.
                 The default is `trainer.tower_func`,
                 but you can change it to a different tower function
...
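As both docstrings note, passing an explicit tower function lets the inference graph differ from the training graph. A hedged sketch, where `build_pred_graph` and `val_df` are hypothetical placeholders for an inference-only graph builder and a validation DataFlow:

```python
from tensorpack.callbacks import InferenceRunner, ScalarStats
from tensorpack.tfutils.tower import TowerFunc

# Build the inference tower with a different function than training.
pred_func = TowerFunc(build_pred_graph, model.get_input_signature())
callbacks = [
    InferenceRunner(val_df, [ScalarStats('cost')],
                    tower_name='InferenceTower2',  # must be unique per runner
                    tower_func=pred_func),
]
```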
@@ -6,7 +6,6 @@ from collections import namedtuple
 import tensorflow as tf

 from ..utils.argtools import memoized_method
-from ..utils.develop import deprecated
 from ..tfutils.common import get_op_tensor_name
 from ..compat import backport_tensor_spec, tfv1
@@ -174,7 +173,3 @@ class ModelDesc(ModelDescBase):
         A subclass is expected to implement this method.
         """
         raise NotImplementedError()
-
-    @deprecated("Just use `build_graph` instead!")
-    def _build_graph_get_cost(self, *inputs):
-        return self.build_graph(*inputs)
...
@@ -9,7 +9,7 @@ from ..graph_builder import ModelDescBase
 from ..tfutils import get_default_sess_config
 from ..tfutils.sessinit import JustCurrentSession, SessionInit
 from ..tfutils.sesscreate import NewSessionCreator
-from ..tfutils.tower import TowerFuncWrapper
+from ..tfutils.tower import TowerFunc
 from ..utils import logger

 __all__ = ['PredictConfig']
@@ -36,7 +36,7 @@ class PredictConfig(object):
     This can be provided in the following ways:

     1. `model`: a :class:`ModelDesc` instance. It will contain a tower function by itself.
-    2. `tower_func`: a :class:`tfutils.TowerFuncWrapper` instance.
+    2. `tower_func`: a :class:`tfutils.TowerFunc` instance.
        Provide a tower function instance directly.
     3. `tower_func`: a symbolic function and `input_signature`: the signature of the function.
        Provide both a function and its signature.
@@ -52,8 +52,8 @@ class PredictConfig(object):
         Args:
             model (ModelDescBase): to be used to construct a tower function.
             tower_func: a callable which takes input tensors (by positional args) and construct a tower,
-                or a :class:`tfutils.TowerFuncWrapper` instance.
-            input_signature ([tf.TensorSpec]): if tower_func is a plain function (instead of a TowerFuncWrapper),
+                or a :class:`tfutils.TowerFunc` instance.
+            input_signature ([tf.TensorSpec]): if tower_func is a plain function (instead of a TowerFunc),
                 this describes the list of inputs it takes.
             input_names (list): a list of input tensor names. Defaults to match input_signature.
@@ -85,13 +85,13 @@ class PredictConfig(object):
             assert_type(model, ModelDescBase, 'model')
             assert input_signature is None and tower_func is None
             self.input_signature = model.get_input_signature()
-            self.tower_func = TowerFuncWrapper(model.build_graph, self.input_signature)
+            self.tower_func = TowerFunc(model.build_graph, self.input_signature)
         else:
-            if isinstance(tower_func, TowerFuncWrapper):
+            if isinstance(tower_func, TowerFunc):
                 input_signature = tower_func.input_signature
             assert input_signature is not None and tower_func is not None
             self.input_signature = input_signature
-            self.tower_func = TowerFuncWrapper(tower_func, input_signature)
+            self.tower_func = TowerFunc(tower_func, input_signature)

         if session_init is None:
             session_init = JustCurrentSession()
...
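The constructor branches above accept the tower function in the three documented forms; a sketch of each call style, where `MyModel`, `build_graph_fn`, and `signature` are hypothetical:

```python
from tensorpack import PredictConfig
from tensorpack.tfutils.tower import TowerFunc

# 1. A ModelDesc instance brings its own tower function and signature.
cfg = PredictConfig(model=MyModel(),
                    input_names=['input'], output_names=['prob'])

# 2. A TowerFunc instance; its input_signature is picked up automatically.
cfg = PredictConfig(tower_func=TowerFunc(build_graph_fn, signature),
                    input_names=['input'], output_names=['prob'])

# 3. A plain function plus an explicit signature.
cfg = PredictConfig(tower_func=build_graph_fn, input_signature=signature,
                    input_names=['input'], output_names=['prob'])
```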
@@ -5,7 +5,6 @@
 import tensorflow as tf

 from ..compat import tfv1
-from ..utils.develop import deprecated

 __all__ = ['print_stat', 'rms']
@@ -37,7 +36,6 @@ def rms(x, name=None):
 # don't hurt to leave it here
-@deprecated("Please implement it by yourself.", "2018-04-28")
 def psnr(prediction, ground_truth, maxp=None, name='psnr'):
     """`Peak Signal to Noise Ratio <https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio>`_.
...
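Since the old deprecation message for `psnr` asked users to implement it themselves, here is a self-contained sketch of the standard definition, PSNR = 20 * log10(maxp) - 10 * log10(MSE), assuming float tensors with values in [0, maxp]:

```python
import math

import tensorflow as tf

def psnr(prediction, ground_truth, maxp=1.0, name='psnr'):
    """Peak Signal to Noise Ratio: 20 * log10(maxp) - 10 * log10(MSE)."""
    mse = tf.reduce_mean(tf.square(prediction - ground_truth))
    # log10(x) = log(x) / log(10); TF only provides the natural log.
    return tf.identity(
        20.0 * math.log10(maxp) - 10.0 * tf.math.log(mse) / math.log(10.0),
        name=name)
```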
@@ -15,7 +15,8 @@ from ..utils.naming import MOVING_SUMMARY_OPS_KEY
 from .collection import CollectionGuard
 from .common import get_op_or_tensor_by_name, get_op_tensor_name

-__all__ = ['get_current_tower_context', 'BaseTowerContext', 'TowerContext', 'TowerFuncWrapper',
+__all__ = ['get_current_tower_context', 'BaseTowerContext', 'TowerContext',
+           'TowerFuncWrapper', 'TowerFunc',
            'TowerTensorHandle', 'TowerTensorHandles']

 _CurrentTowerContext = None
@@ -245,9 +246,9 @@ def TowerContext(tower_name, is_training, vs_name=''):
         return PredictTowerContext(tower_name, vs_name=vs_name)


-class TowerFuncWrapper(object):
+class TowerFunc(object):
     """
-    A wrapper around a tower function (see
+    A tower function (see
     [tutorial on tower function](http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer)).
     It keeps track of the name scope, variable scope and input/output tensors
     each time the function is called.
@@ -279,10 +280,10 @@ class TowerFuncWrapper(object):
     def __new__(cls, tower_fn, _):
         # to avoid double-wrapping a function
-        if isinstance(tower_fn, TowerFuncWrapper):
+        if isinstance(tower_fn, TowerFunc):
             return tower_fn
         else:
-            return super(TowerFuncWrapper, cls).__new__(cls)
+            return super(TowerFunc, cls).__new__(cls)

     def __call__(self, *args):
         ctx = get_current_tower_context()
@@ -311,6 +312,9 @@ class TowerFuncWrapper(object):
         return self._input_signature

+
+TowerFuncWrapper = TowerFunc
+

 class TowerTensorHandles(object):
     """
     Wrap a list of :class:`TowerTensorHandle`,
...
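The assignment `TowerFuncWrapper = TowerFunc` keeps old imports working while the docs move to the new name. A sketch of what this hunk guarantees; the `signature` below is a hypothetical single-image spec:

```python
import tensorflow as tf
from tensorpack.tfutils.tower import TowerFunc, TowerFuncWrapper

assert TowerFuncWrapper is TowerFunc          # old name is now a plain alias

signature = [tf.TensorSpec([None, 28, 28], tf.float32, 'image')]  # hypothetical

def tower_fn(image):
    pass  # build the forward graph for one tower here

f = TowerFunc(tower_fn, signature)
assert TowerFunc(f, signature) is f           # __new__ avoids double-wrapping
```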
@@ -176,9 +176,14 @@ class AutoResumeTrainConfig(TrainConfig):
     Note that the functionality requires the logging directory to obtain
     necessary information from a previous run.
-    In some cases (e.g. when using Horovod), the directory is not
-    available, or the directories are different for different workers,
-    then this class may not function properly.
+    If you have an unconventional logging-directory setup, this class will not
+    work for you, for example:
+
+    1. If you save the checkpoint to a directory other than the
+       logging directory.
+    2. If in distributed training the directory is not
+       available to every worker, or the directories differ across workers.
     """

     def __init__(self, always_resume=True, **kwargs):
         """
@@ -189,7 +194,7 @@ class AutoResumeTrainConfig(TrainConfig):
             kwargs: same as in :class:`TrainConfig`.

         Note:
-            The main goal of this class is to let a training job to resume
+            The main goal of this class is to let a training job resume
             without changing any line of code or command line arguments.
             So it's useful to let resume take priority over user-provided arguments sometimes.
...
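A minimal sketch of the intended use, assuming the conventional setup the docstring describes (checkpoints and logs in one directory); `MyModel` and `my_dataflow` are hypothetical. Running the same script twice makes the second run resume from the checkpoint found in the logging directory:

```python
from tensorpack import AutoResumeTrainConfig, SimpleTrainer, launch_train_with_config
from tensorpack.utils import logger

logger.auto_set_dir()            # checkpoints and logs share one train_log directory
config = AutoResumeTrainConfig(
    always_resume=True,          # if a checkpoint exists, ignore user-provided session_init
    model=MyModel(),             # hypothetical ModelDesc subclass
    dataflow=my_dataflow,        # hypothetical DataFlow
    max_epoch=100,
)
launch_train_with_config(config, SimpleTrainer())
```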
@@ -85,7 +85,7 @@ def launch_train_with_config(config, trainer):
     # This is the only place where the `ModelDesc` abstraction is useful.
     # We should gradually stay away from this unuseful abstraction.
-    # TowerFuncWrapper is a better abstraction (similar to tf.defun in the future)
+    # TowerFunc is a better abstraction (similar to tf.function in the future)
     trainer.setup_graph(
         model.get_input_signature(), input,
         model.build_graph, model.get_optimizer)
...
@@ -9,7 +9,7 @@ from ..compat import tfv1, is_tfv2
 from ..input_source import PlaceholderInput
 from ..predict.base import OnlinePredictor
 from ..tfutils.gradproc import FilterNoneGrad
-from ..tfutils.tower import PredictTowerContext, TowerFuncWrapper, get_current_tower_context
+from ..tfutils.tower import PredictTowerContext, TowerFunc, get_current_tower_context
 from ..utils import logger
 from ..utils.argtools import call_only_once, memoized
 from ..utils.develop import HIDE_DOC
@@ -38,13 +38,13 @@ class TowerTrainer(Trainer):
     @call_only_once
     def _set_tower_func(self, tower_func):
-        assert isinstance(tower_func, TowerFuncWrapper), tower_func
+        assert isinstance(tower_func, TowerFunc), tower_func
         self._tower_func = tower_func

     @property
     def tower_func(self):
         """
-        A :class:`TowerFuncWrapper` instance.
+        A :class:`TowerFunc` instance.
         See [tutorial on tower function](http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer)
         for more information.
         """
@@ -215,7 +215,7 @@ class SingleCostTrainer(TowerTrainer):
             It must follows the `rules of tower function.
             <http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer>`_.
         """
-        get_cost_fn = TowerFuncWrapper(get_cost_fn, input_signature)
+        get_cost_fn = TowerFunc(get_cost_fn, input_signature)
         get_opt_fn = memoized(get_opt_fn)

         self.tower_func = get_cost_fn
...
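For reference, `setup_graph` performs this wrapping itself, which is why `launch_train_with_config` (earlier hunk) can pass in a plain `model.build_graph`. A simplified sketch of that call:

```python
# Simplified from launch_train_with_config: a plain function is passed in,
# and setup_graph wraps it via TowerFunc(get_cost_fn, input_signature).
trainer.setup_graph(
    model.get_input_signature(),   # [tf.TensorSpec]
    input,                         # an InputSource
    model.build_graph,             # plain tower function returning the cost
    model.get_optimizer,
)
```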
@@ -18,7 +18,7 @@ from ..tfutils.sesscreate import NewSessionCreator
 from ..tfutils.tower import TrainTowerContext
 from ..utils import logger
 from ..utils.argtools import map_arg
-from ..utils.develop import HIDE_DOC
+from ..utils.develop import HIDE_DOC, deprecated
 from .tower import SingleCostTrainer

 __all__ = ['NoOpTrainer', 'SimpleTrainer',
@@ -66,6 +66,7 @@ class NoOpTrainer(SimpleTrainer):
 # Only exists for type check & back-compatibility
 class QueueInputTrainer(SimpleTrainer):
+    @deprecated("SimpleTrainer is sufficient!", "2019-12-31")
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
         assert isinstance(input, QueueInput), input
         return super(QueueInputTrainer, self)._setup_graph(input, get_cost_fn, get_opt_fn)
...