Commit e2eba742 authored by Yuxin Wu

cleanup deprecations

parent e0a7e8f9
@@ -377,8 +377,6 @@ _DEPRECATED_NAMES = set([
     'dump_dataflow_to_process_queue',
     'DistributedTrainerReplicated',
     'DistributedTrainerParameterServer',
-    'InputDesc',
-    'inputs_desc',
     'Augmentor',
     "get_model_loader",
...
@@ -13,7 +13,6 @@ from termcolor import colored
 from ..utils import logger
 from ..utils.utils import get_rng, get_tqdm, get_tqdm_kwargs
-from ..utils.develop import log_deprecated
 from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow, RNGDataFlow
 
 try:
@@ -622,7 +621,7 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
        because it does not make sense to stop the iteration anywhere.
     """
-    def __init__(self, ds, buffer_size, num_reuse=1, shuffle_interval=None, nr_reuse=None):
+    def __init__(self, ds, buffer_size, num_reuse=1, shuffle_interval=None):
         """
         Args:
             ds (DataFlow): input DataFlow.
@@ -633,11 +632,7 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
                datapoints were produced from the given dataflow. Frequent shuffle on large buffer
                may affect speed, but infrequent shuffle may not provide enough randomness.
                Defaults to buffer_size / 3
-            nr_reuse: deprecated name for num_reuse
         """
-        if nr_reuse is not None:
-            log_deprecated("LocallyShuffleData(nr_reuse=...)", "Renamed to 'num_reuse'.", "2020-01-01")
-            num_reuse = nr_reuse
         ProxyDataFlow.__init__(self, ds)
         self.q = deque(maxlen=buffer_size)
         if shuffle_interval is None:
...
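For callers of LocallyShuffleData, a minimal sketch of the call after this cleanup; FakeData below is only a stand-in input dataflow, not part of the commit:

from tensorpack.dataflow import FakeData, LocallyShuffleData

# The deprecated `nr_reuse=` keyword is gone; spell it `num_reuse=`.
ds = FakeData([[32, 32, 3]], size=1000)          # dummy DataFlow, for illustration only
ds = LocallyShuffleData(ds, buffer_size=128, num_reuse=1)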
@@ -16,7 +16,6 @@ import zmq
 from six.moves import queue, range
 
 from ..utils import logger
-from ..utils.develop import log_deprecated
 from ..utils.concurrency import (
     StoppableThread, enable_death_signal, ensure_proc_terminate, start_proc_mask_signal)
 from ..utils.serialize import dumps_once as dumps, loads_once as loads
@@ -193,24 +192,14 @@ class MultiProcessRunner(ProxyDataFlow):
             for dp in self.ds:
                 self.queue.put(dp)
 
-    def __init__(self, ds, num_prefetch=None, num_proc=None, nr_prefetch=None, nr_proc=None):
+    def __init__(self, ds, num_prefetch, num_proc):
         """
         Args:
             ds (DataFlow): input DataFlow.
             num_prefetch (int): size of the queue to hold prefetched datapoints.
                 Required.
             num_proc (int): number of processes to use. Required.
-            nr_prefetch, nr_proc: deprecated argument names
         """
-        if nr_prefetch is not None:
-            log_deprecated("MultiProcessRunner(nr_prefetch)", "Renamed to 'num_prefetch'", "2020-01-01")
-            num_prefetch = nr_prefetch
-        if nr_proc is not None:
-            log_deprecated("MultiProcessRunner(nr_proc)", "Renamed to 'num_proc'", "2020-01-01")
-            num_proc = nr_proc
-        if num_prefetch is None or num_proc is None:
-            raise TypeError("Missing argument num_prefetch or num_proc in MultiProcessRunner!")
         # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#the-spawn-and-forkserver-start-methods
         if os.name == 'nt':
             logger.warn("MultiProcessRunner does support Windows. \
@@ -333,17 +322,13 @@ class MultiProcessRunnerZMQ(_MultiProcessZMQDataFlow):
                 socket.close(0)
                 context.destroy(0)
 
-    def __init__(self, ds, num_proc=1, hwm=50, nr_proc=None):
+    def __init__(self, ds, num_proc=1, hwm=50):
         """
         Args:
             ds (DataFlow): input DataFlow.
             num_proc (int): number of processes to use.
            hwm (int): the zmq "high-water mark" (queue size) for both sender and receiver.
-            nr_proc: deprecated
         """
-        if nr_proc is not None:
-            log_deprecated("MultiProcessRunnerZMQ(nr_proc)", "Renamed to 'num_proc'", "2020-01-01")
-            num_proc = nr_proc
         super(MultiProcessRunnerZMQ, self).__init__()
         self.ds = ds
@@ -443,7 +428,7 @@ class MultiThreadRunner(DataFlow):
             finally:
                 self.stop()
 
-    def __init__(self, get_df, num_prefetch=None, num_thread=None, nr_prefetch=None, nr_thread=None):
+    def __init__(self, get_df, num_prefetch, num_thread):
         """
         Args:
             get_df ( -> DataFlow): a callable which returns a DataFlow.
@@ -452,17 +437,7 @@ class MultiThreadRunner(DataFlow):
                unless your dataflow is stateless.
             num_prefetch (int): size of the queue
             num_thread (int): number of threads
-            nr_prefetch, nr_thread: deprecated names
         """
-        if nr_prefetch is not None:
-            log_deprecated("MultiThreadRunner(nr_prefetch)", "Renamed to 'num_prefetch'", "2020-01-01")
-            num_prefetch = nr_prefetch
-        if nr_thread is not None:
-            log_deprecated("MultiThreadRunner(nr_thread)", "Renamed to 'num_thread'", "2020-01-01")
-            num_thread = nr_thread
-        if num_prefetch is None or num_thread is None:
-            raise TypeError("Missing argument num_prefetch or num_thread in MultiThreadRunner!")
         assert num_thread > 0, num_thread
         assert num_prefetch > 0, num_prefetch
         self.num_thread = num_thread
...
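With the fallback keywords removed from the runners above, the counts must be passed under their new names, and MultiProcessRunner / MultiThreadRunner now require them explicitly. A rough sketch, again with FakeData as a placeholder input:

from tensorpack.dataflow import FakeData, MultiProcessRunnerZMQ, MultiThreadRunner

ds = FakeData([[64, 64, 3]], size=1000)

# `nr_proc=` no longer exists; the process count is `num_proc=`.
ds_parallel = MultiProcessRunnerZMQ(ds, num_proc=4, hwm=50)

# Both arguments are now mandatory; there is no deprecated
# `nr_prefetch=` / `nr_thread=` spelling to fall back on.
ds_threaded = MultiThreadRunner(lambda: FakeData([[64, 64, 3]], size=1000),
                                num_prefetch=256, num_thread=8)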
@@ -10,7 +10,6 @@ from six.moves import queue
 from ..utils.concurrency import StoppableThread, enable_death_signal
 from ..utils.serialize import dumps_once as dumps, loads_once as loads
-from ..utils.develop import log_deprecated
 from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow
 from .common import RepeatedData, BatchData
 from .parallel import _bind_guard, _get_pipe_name, _MultiProcessZMQDataFlow, _repeat_iter, _zmq_catch_error
@@ -143,7 +142,7 @@ class MultiThreadMapData(_ParallelMapData):
             finally:
                 self.stop()
 
-    def __init__(self, ds, num_thread=None, map_func=None, buffer_size=200, strict=False, nr_thread=None):
+    def __init__(self, ds, num_thread=None, map_func=None, *, buffer_size=200, strict=False):
         """
         Args:
             ds (DataFlow): the dataflow to map
@@ -152,12 +151,7 @@ class MultiThreadMapData(_ParallelMapData):
                discard/skip the datapoint.
             buffer_size (int): number of datapoints in the buffer
             strict (bool): use "strict mode", see notes above.
-            nr_thread: deprecated name
         """
-        if nr_thread is not None:
-            log_deprecated("MultiThreadMapData(nr_thread)", "Renamed to 'num_thread'", "2020-01-01")
-            num_thread = nr_thread
         if strict:
             # In strict mode, buffer size cannot be larger than the total number of datapoints
             try:
@@ -255,7 +249,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
                 dp = self.map_func(dp)
                 socket.send(dumps(dp), copy=False)
 
-    def __init__(self, ds, num_proc=None, map_func=None, buffer_size=200, strict=False, nr_proc=None):
+    def __init__(self, ds, num_proc=None, map_func=None, *, buffer_size=200, strict=False):
         """
         Args:
             ds (DataFlow): the dataflow to map
@@ -264,11 +258,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
                discard/skip the datapoint.
             buffer_size (int): number of datapoints in the buffer
             strict (bool): use "strict mode", see notes above.
-            nr_proc: deprecated name
         """
-        if nr_proc is not None:
-            log_deprecated("MultiProcessMapDataZMQ(nr_proc)", "Renamed to 'num_proc'", "2020-01-01")
-            num_proc = nr_proc
         if strict:
             # In strict mode, buffer size cannot be larger than the total number of datapoints
             try:
...
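Note that the rewritten signatures also add a bare `*`, so buffer_size and strict become keyword-only arguments. A usage sketch under that assumption, with a trivial placeholder map function:

from tensorpack.dataflow import FakeData, MultiProcessMapDataZMQ

def _identity_map(dp):
    # placeholder map function; a real one would transform the datapoint
    return dp

ds = FakeData([[224, 224, 3]], size=1000)
# `nr_proc=` is removed, and buffer_size/strict can only be passed by keyword now.
ds = MultiProcessMapDataZMQ(ds, num_proc=4, map_func=_identity_map,
                            buffer_size=300, strict=True)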
@@ -2,35 +2,7 @@
 # File: model_desc.py
 
-from collections import namedtuple
-import tensorflow as tf
-from ..utils.develop import log_deprecated
 from ..train.model_desc import ModelDesc, ModelDescBase  # kept for BC  # noqa
 
-__all__ = ['InputDesc']
+__all__ = []
-
-
-class InputDesc(
-        namedtuple('InputDescTuple', ['type', 'shape', 'name'])):
-    """
-    An equivalent of `tf.TensorSpec`.
-
-    History: this concept is used to represent metadata about the inputs,
-    which can be later used to build placeholders or other types of input source.
-    It is introduced much much earlier than the equivalent concept `tf.TensorSpec`
-    was introduced in TensorFlow.
-    Therefore, we now switched to use `tf.TensorSpec`, but keep this here for compatibility reasons.
-    """
-
-    def __new__(cls, type, shape, name):
-        """
-        Args:
-            type (tf.DType):
-            shape (tuple):
-            name (str):
-        """
-        log_deprecated("InputDesc", "Use tf.TensorSpec instead!", "2020-03-01")
-        assert isinstance(type, tf.DType), type
-        return tf.TensorSpec(shape=shape, dtype=type, name=name)
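Since the removed InputDesc already forwarded to tf.TensorSpec, migrating callers is mechanical; an equivalent spec looks like:

import tensorflow as tf

# InputDesc(tf.float32, (None, 28, 28, 1), 'input') becomes:
spec = tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='input')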
@@ -10,7 +10,6 @@ from ..tfutils.sessinit import JustCurrentSession, SessionInit
 from ..tfutils.sesscreate import NewSessionCreator
 from ..tfutils.tower import TowerFunc
 from ..utils import logger
-from ..utils.develop import log_deprecated
 
 __all__ = ['PredictConfig']
@@ -28,7 +27,6 @@ class PredictConfig(object):
                  session_init=None,
                  return_input=False,
                  create_graph=True,
-                 inputs_desc=None
                  ):
         """
         Users need to provide enough arguments to create a tower function,
@@ -69,18 +67,12 @@ class PredictConfig(object):
             return_input (bool): same as in :attr:`PredictorBase.return_input`.
             create_graph (bool): create a new graph, or use the default graph
                when predictor is first initialized.
-            inputs_desc (list[tf.TensorSpec]): old (deprecated) name for `input_signature`.
         """
 
         def assert_type(v, tp, name):
             assert isinstance(v, tp), \
                 "Argument '{}' has to be type '{}', but an object of type '{}' found.".format(
                     name, tp.__name__, v.__class__.__name__)
 
-        if inputs_desc is not None:
-            log_deprecated("PredictConfig(inputs_desc)", "Use input_signature instead!", "2020-03-01")
-            assert input_signature is None, "Cannot set both inputs_desc and input_signature!"
-            input_signature = inputs_desc
-
         if model is not None:
             assert_type(model, ModelDescBase, 'model')
             assert input_signature is None and tower_func is None
@@ -120,8 +112,6 @@ class PredictConfig(object):
         self.return_input = bool(return_input)
         self.create_graph = bool(create_graph)
 
-        self.inputs_desc = input_signature  # TODO a little bit of compatibility
-
     def _maybe_create_graph(self):
         if self.create_graph:
             return tf.Graph()
...
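With inputs_desc gone from PredictConfig, the specs go to input_signature. A hedged sketch; the tower function and tensor names below are invented for illustration:

import tensorflow as tf
from tensorpack.predict import PredictConfig

def tower_func(image):
    # hypothetical predict tower: produces an output tensor named 'prob'
    tf.identity(tf.reduce_mean(image, axis=[1, 2, 3]), name='prob')

config = PredictConfig(
    tower_func=tower_func,
    input_signature=[tf.TensorSpec((None, 28, 28, 1), tf.float32, 'image')],
    input_names=['image'],
    output_names=['prob'],
)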
@@ -8,7 +8,7 @@ import six
 from ..compat import tfv1 as tf
 from ..utils import logger
 from ..utils.argtools import call_only_once
-from ..utils.develop import HIDE_DOC, log_deprecated
+from ..utils.develop import HIDE_DOC
 from ..utils.naming import MOVING_SUMMARY_OPS_KEY
 from .collection import CollectionGuard
 from .common import get_op_or_tensor_by_name, get_op_tensor_name
@@ -305,11 +305,6 @@ class TowerFunc(object):
     def input_signature(self):
         return self._input_signature
 
-    @property
-    def inputs_desc(self):
-        log_deprecated("TowerFunc.inputs_desc", "Use .input_signature instead", "2020-03-01")
-        return self._input_signature
-
 
 TowerFuncWrapper = TowerFunc
...
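Code that reads the signature back from a TowerFunc has to switch to the remaining property; for example (the tower function and signature are assumed to be defined elsewhere):

from tensorpack.tfutils.tower import TowerFunc

tower = TowerFunc(my_tower_fn, my_signature)   # both arguments are hypothetical here
specs = tower.input_signature                  # was: tower.inputs_desc, removed above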
@@ -4,7 +4,6 @@
 import tensorflow as tf
 
-from ..utils.develop import log_deprecated, HIDE_DOC
 from ..utils.argtools import memoized_method
 from ..tfutils.common import get_op_tensor_name
 from ..tfutils.tower import get_current_tower_context
@@ -27,11 +26,6 @@ class ModelDescBase(object):
    together define a tower function.
     """
 
-    @HIDE_DOC
-    def get_inputs_desc(self):
-        log_deprecated("ModelDesc.get_inputs_desc", "Use get_input_signature instead!", "2020-03-01")
-        return self.get_input_signature()
-
     @memoized_method
     def get_input_signature(self):
         """
...
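Similarly, ModelDesc users now call get_input_signature(); a tiny sketch with a hypothetical model trimmed to the relevant method:

import tensorflow as tf
from tensorpack import ModelDesc

class MyModel(ModelDesc):      # hypothetical, for illustration only
    def inputs(self):
        return [tf.TensorSpec((None, 28, 28, 1), tf.float32, 'image')]

specs = MyModel().get_input_signature()    # was: MyModel().get_inputs_desc()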
@@ -12,7 +12,7 @@ from ..tfutils.gradproc import FilterNoneGrad
 from ..tfutils.tower import PredictTowerContext, TowerFunc, get_current_tower_context
 from ..utils import logger
 from ..utils.argtools import call_only_once, memoized
-from ..utils.develop import HIDE_DOC, log_deprecated
+from ..utils.develop import HIDE_DOC
 from .base import Trainer
 
 __all__ = ['SingleCostTrainer', 'TowerTrainer']
@@ -56,11 +56,6 @@ class TowerTrainer(Trainer):
     def tower_func(self, val):
         self._set_tower_func(val)
 
-    @property
-    def inputs_desc(self):
-        log_deprecated("TowerTrainer.inputs_desc", "Use .input_signature instead!", "2020-03-01")
-        return self.input_signature
-
     @property
     def input_signature(self):
         """
...
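The same rename applies on the trainer side; e.g., assuming a trainer whose tower function has already been set:

from tensorpack import SimpleTrainer

trainer = SimpleTrainer()
trainer.tower_func = my_tower_func     # hypothetical TowerFunc, defined elsewhere
specs = trainer.input_signature        # was: trainer.inputs_desc, removed above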