Commit e2eba742 authored by Yuxin Wu

cleanup deprecations

parent e0a7e8f9
@@ -377,8 +377,6 @@ _DEPRECATED_NAMES = set([
'dump_dataflow_to_process_queue',
'DistributedTrainerReplicated',
'DistributedTrainerParameterServer',
'InputDesc',
'inputs_desc',
'Augmentor',
"get_model_loader",
......
@@ -13,7 +13,6 @@ from termcolor import colored
from ..utils import logger
from ..utils.utils import get_rng, get_tqdm, get_tqdm_kwargs
from ..utils.develop import log_deprecated
from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow, RNGDataFlow
try:
@@ -622,7 +621,7 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
because it does not make sense to stop the iteration anywhere.
"""
def __init__(self, ds, buffer_size, num_reuse=1, shuffle_interval=None, nr_reuse=None):
def __init__(self, ds, buffer_size, num_reuse=1, shuffle_interval=None):
"""
Args:
ds (DataFlow): input DataFlow.
@@ -633,11 +632,7 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
datapoints were produced from the given dataflow. Frequent shuffling of a large buffer
may affect speed, while infrequent shuffling may not provide enough randomness.
Defaults to buffer_size / 3.
nr_reuse: deprecated name for num_reuse
"""
if nr_reuse is not None:
log_deprecated("LocallyShuffleData(nr_reuse=...)", "Renamed to 'num_reuse'.", "2020-01-01")
num_reuse = nr_reuse
ProxyDataFlow.__init__(self, ds)
self.q = deque(maxlen=buffer_size)
if shuffle_interval is None:
......
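For users of `LocallyShuffleData`, the only visible change here is that the deprecated `nr_reuse` keyword is gone. A minimal migration sketch, using `FakeData` as a stand-in source dataflow (shapes and sizes are illustrative):

    from tensorpack.dataflow import FakeData, LocallyShuffleData

    df = FakeData([[32, 32, 3]], size=1000)   # toy source DataFlow
    # old call (no longer accepted): LocallyShuffleData(df, buffer_size=256, nr_reuse=1)
    df = LocallyShuffleData(df, buffer_size=256, num_reuse=1)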
@@ -16,7 +16,6 @@ import zmq
from six.moves import queue, range
from ..utils import logger
from ..utils.develop import log_deprecated
from ..utils.concurrency import (
StoppableThread, enable_death_signal, ensure_proc_terminate, start_proc_mask_signal)
from ..utils.serialize import dumps_once as dumps, loads_once as loads
@@ -193,24 +192,14 @@ class MultiProcessRunner(ProxyDataFlow):
for dp in self.ds:
self.queue.put(dp)
def __init__(self, ds, num_prefetch=None, num_proc=None, nr_prefetch=None, nr_proc=None):
def __init__(self, ds, num_prefetch, num_proc):
"""
Args:
ds (DataFlow): input DataFlow.
num_prefetch (int): size of the queue to hold prefetched datapoints.
Required.
num_proc (int): number of processes to use. Required.
nr_prefetch, nr_proc: deprecated argument names
"""
if nr_prefetch is not None:
log_deprecated("MultiProcessRunner(nr_prefetch)", "Renamed to 'num_prefetch'", "2020-01-01")
num_prefetch = nr_prefetch
if nr_proc is not None:
log_deprecated("MultiProcessRunner(nr_proc)", "Renamed to 'num_proc'", "2020-01-01")
num_proc = nr_proc
if num_prefetch is None or num_proc is None:
raise TypeError("Missing argument num_prefetch or num_proc in MultiProcessRunner!")
# https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#the-spawn-and-forkserver-start-methods
if os.name == 'nt':
logger.warn("MultiProcessRunner does support Windows. \
@@ -333,17 +322,13 @@ class MultiProcessRunnerZMQ(_MultiProcessZMQDataFlow):
socket.close(0)
context.destroy(0)
def __init__(self, ds, num_proc=1, hwm=50, nr_proc=None):
def __init__(self, ds, num_proc=1, hwm=50):
"""
Args:
ds (DataFlow): input DataFlow.
num_proc (int): number of processes to use.
hwm (int): the zmq "high-water mark" (queue size) for both sender and receiver.
nr_proc: deprecated
"""
if nr_proc is not None:
log_deprecated("MultiProcessRunnerZMQ(nr_proc)", "Renamed to 'num_proc'", "2020-01-01")
num_proc = nr_proc
super(MultiProcessRunnerZMQ, self).__init__()
self.ds = ds
@@ -443,7 +428,7 @@ class MultiThreadRunner(DataFlow):
finally:
self.stop()
def __init__(self, get_df, num_prefetch=None, num_thread=None, nr_prefetch=None, nr_thread=None):
def __init__(self, get_df, num_prefetch, num_thread):
"""
Args:
get_df ( -> DataFlow): a callable which returns a DataFlow.
@@ -452,17 +437,7 @@ class MultiThreadRunner(DataFlow):
unless your dataflow is stateless.
num_prefetch (int): size of the queue
num_thread (int): number of threads
nr_prefetch, nr_thread: deprecated names
"""
if nr_prefetch is not None:
log_deprecated("MultiThreadRunner(nr_prefetch)", "Renamed to 'num_prefetch'", "2020-01-01")
num_prefetch = nr_prefetch
if nr_thread is not None:
log_deprecated("MultiThreadRunner(nr_thread)", "Renamed to 'num_thread'", "2020-01-01")
num_thread = nr_thread
if num_prefetch is None or num_thread is None:
raise TypeError("Missing argument num_prefetch or num_thread in MultiThreadRunner!")
assert num_thread > 0, num_thread
assert num_prefetch > 0, num_prefetch
self.num_thread = num_thread
......
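With the deprecated `nr_*` keywords removed from the runners above, `num_prefetch` / `num_proc` (and `num_thread` for `MultiThreadRunner`) are the only accepted names, and `MultiProcessRunner` / `MultiThreadRunner` now require them outright. A hedged usage sketch with a toy dataflow and arbitrary worker counts:

    from tensorpack.dataflow import FakeData, MultiProcessRunner, MultiProcessRunnerZMQ

    df = FakeData([[224, 224, 3]], size=1000)   # toy source DataFlow
    # nr_prefetch / nr_proc no longer exist; the new names are mandatory:
    df = MultiProcessRunner(df, num_prefetch=64, num_proc=4)
    # the ZMQ-based variant keeps its defaults, but likewise only accepts num_proc:
    # df = MultiProcessRunnerZMQ(df, num_proc=4, hwm=50)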
@@ -10,7 +10,6 @@ from six.moves import queue
from ..utils.concurrency import StoppableThread, enable_death_signal
from ..utils.serialize import dumps_once as dumps, loads_once as loads
from ..utils.develop import log_deprecated
from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow
from .common import RepeatedData, BatchData
from .parallel import _bind_guard, _get_pipe_name, _MultiProcessZMQDataFlow, _repeat_iter, _zmq_catch_error
@@ -143,7 +142,7 @@ class MultiThreadMapData(_ParallelMapData):
finally:
self.stop()
def __init__(self, ds, num_thread=None, map_func=None, buffer_size=200, strict=False, nr_thread=None):
def __init__(self, ds, num_thread=None, map_func=None, *, buffer_size=200, strict=False):
"""
Args:
ds (DataFlow): the dataflow to map
@@ -152,12 +151,7 @@ class MultiThreadMapData(_ParallelMapData):
discard/skip the datapoint.
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
nr_thread: deprecated name
"""
if nr_thread is not None:
log_deprecated("MultiThreadMapData(nr_thread)", "Renamed to 'num_thread'", "2020-01-01")
num_thread = nr_thread
if strict:
# In strict mode, buffer size cannot be larger than the total number of datapoints
try:
@@ -255,7 +249,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
dp = self.map_func(dp)
socket.send(dumps(dp), copy=False)
def __init__(self, ds, num_proc=None, map_func=None, buffer_size=200, strict=False, nr_proc=None):
def __init__(self, ds, num_proc=None, map_func=None, *, buffer_size=200, strict=False):
"""
Args:
ds (DataFlow): the dataflow to map
@@ -264,11 +258,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
discard/skip the datapoint.
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
nr_proc: deprecated name
"""
if nr_proc is not None:
log_deprecated("MultiProcessMapDataZMQ(nr_proc)", "Renamed to 'num_proc'", "2020-01-01")
num_proc = nr_proc
if strict:
# In strict mode, buffer size cannot be larger than the total number of datapoints
try:
......
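Besides dropping `nr_thread` / `nr_proc`, both mappers now declare `buffer_size` and `strict` as keyword-only arguments (the added `*`). A sketch of a valid call, using an identity `map_func` purely for illustration:

    from tensorpack.dataflow import FakeData, MultiThreadMapData

    df = FakeData([[8]], size=100)   # toy source DataFlow
    # buffer_size / strict must now be passed by keyword:
    df = MultiThreadMapData(df, num_thread=4, map_func=lambda dp: dp,
                            buffer_size=32, strict=False)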
@@ -2,35 +2,7 @@
# File: model_desc.py
from collections import namedtuple
import tensorflow as tf
from ..utils.develop import log_deprecated
from ..train.model_desc import ModelDesc, ModelDescBase # kept for BC # noqa
__all__ = ['InputDesc']
class InputDesc(
namedtuple('InputDescTuple', ['type', 'shape', 'name'])):
"""
An equivalent of `tf.TensorSpec`.
History: this concept was used to represent metadata about the inputs,
which could later be used to build placeholders or other types of input source.
It was introduced long before the equivalent concept `tf.TensorSpec`
appeared in TensorFlow.
We have therefore switched to `tf.TensorSpec`, but keep this class here for compatibility reasons.
"""
def __new__(cls, type, shape, name):
"""
Args:
type (tf.DType):
shape (tuple):
name (str):
"""
log_deprecated("InputDesc", "Use tf.TensorSpec instead!", "2020-03-01")
assert isinstance(type, tf.DType), type
return tf.TensorSpec(shape=shape, dtype=type, name=name)
__all__ = []
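Since `InputDesc.__new__` already returned a `tf.TensorSpec`, migration is a direct substitution. An illustrative equivalent (dtype, shape and name below are placeholders):

    import tensorflow as tf

    # old (removed): InputDesc(tf.float32, (None, 224, 224, 3), 'input')
    spec = tf.TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='input')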
@@ -10,7 +10,6 @@ from ..tfutils.sessinit import JustCurrentSession, SessionInit
from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.tower import TowerFunc
from ..utils import logger
from ..utils.develop import log_deprecated
__all__ = ['PredictConfig']
@@ -28,7 +27,6 @@ class PredictConfig(object):
session_init=None,
return_input=False,
create_graph=True,
inputs_desc=None
):
"""
Users need to provide enough arguments to create a tower function,
@@ -69,18 +67,12 @@ class PredictConfig(object):
return_input (bool): same as in :attr:`PredictorBase.return_input`.
create_graph (bool): create a new graph, or use the default graph
when predictor is first initialized.
inputs_desc (list[tf.TensorSpec]): old (deprecated) name for `input_signature`.
"""
def assert_type(v, tp, name):
assert isinstance(v, tp), \
"Argument '{}' has to be type '{}', but an object of type '{}' found.".format(
name, tp.__name__, v.__class__.__name__)
if inputs_desc is not None:
log_deprecated("PredictConfig(inputs_desc)", "Use input_signature instead!", "2020-03-01")
assert input_signature is None, "Cannot set both inputs_desc and input_signature!"
input_signature = inputs_desc
if model is not None:
assert_type(model, ModelDescBase, 'model')
assert input_signature is None and tower_func is None
@@ -120,8 +112,6 @@ class PredictConfig(object):
self.return_input = bool(return_input)
self.create_graph = bool(create_graph)
self.inputs_desc = input_signature # TODO a little bit of compatibility
def _maybe_create_graph(self):
if self.create_graph:
return tf.Graph()
......
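Callers that passed `inputs_desc=` to `PredictConfig` should now pass `input_signature=`. A hedged sketch, assuming `my_tower_func` is a tower function defined elsewhere (a `ModelDesc` via `model=` works as well):

    import tensorflow as tf
    from tensorpack import PredictConfig

    config = PredictConfig(
        tower_func=my_tower_func,   # assumed to exist elsewhere
        input_signature=[tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')],  # was inputs_desc=
        input_names=['input'],
        output_names=['output'],
    )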
@@ -8,7 +8,7 @@ import six
from ..compat import tfv1 as tf
from ..utils import logger
from ..utils.argtools import call_only_once
from ..utils.develop import HIDE_DOC, log_deprecated
from ..utils.develop import HIDE_DOC
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .collection import CollectionGuard
from .common import get_op_or_tensor_by_name, get_op_tensor_name
@@ -305,11 +305,6 @@ class TowerFunc(object):
def input_signature(self):
return self._input_signature
@property
def inputs_desc(self):
log_deprecated("TowerFunc.inputs_desc", "Use .input_signature instead", "2020-03-01")
return self._input_signature
TowerFuncWrapper = TowerFunc
......
@@ -4,7 +4,6 @@
import tensorflow as tf
from ..utils.develop import log_deprecated, HIDE_DOC
from ..utils.argtools import memoized_method
from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context
@@ -27,11 +26,6 @@ class ModelDescBase(object):
together define a tower function.
"""
@HIDE_DOC
def get_inputs_desc(self):
log_deprecated("ModelDesc.get_inputs_desc", "Use get_input_signature instead!", "2020-03-01")
return self.get_input_signature()
@memoized_method
def get_input_signature(self):
"""
......
@@ -12,7 +12,7 @@ from ..tfutils.gradproc import FilterNoneGrad
from ..tfutils.tower import PredictTowerContext, TowerFunc, get_current_tower_context
from ..utils import logger
from ..utils.argtools import call_only_once, memoized
from ..utils.develop import HIDE_DOC, log_deprecated
from ..utils.develop import HIDE_DOC
from .base import Trainer
__all__ = ['SingleCostTrainer', 'TowerTrainer']
@@ -56,11 +56,6 @@ class TowerTrainer(Trainer):
def tower_func(self, val):
self._set_tower_func(val)
@property
def inputs_desc(self):
log_deprecated("TowerTrainer.inputs_desc", "Use .input_signature instead!", "2020-03-01")
return self.input_signature
@property
def input_signature(self):
"""
......
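The removed `inputs_desc` accessors on `TowerFunc`, `ModelDescBase` and `TowerTrainer` all have `input_signature` counterparts that were already present, so call sites only need a rename. A sketch, assuming `tower_func`, `model` and `trainer` are existing instances of those classes:

    specs = tower_func.input_signature    # was tower_func.inputs_desc
    specs = model.get_input_signature()   # was model.get_inputs_desc()
    specs = trainer.input_signature       # was trainer.inputs_desc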