Commit 377fba1f authored by Yuxin Wu's avatar Yuxin Wu

clean-up some deprecation

parent 73b63247
...@@ -353,7 +353,6 @@ def process_signature(app, what, name, obj, options, signature, ...@@ -353,7 +353,6 @@ def process_signature(app, what, name, obj, options, signature,
def autodoc_skip_member(app, what, name, obj, skip, options): def autodoc_skip_member(app, what, name, obj, skip, options):
if name in [ if name in [
'DistributedReplicatedTrainer',
'SingleCostFeedfreeTrainer', 'SingleCostFeedfreeTrainer',
'SimpleFeedfreeTrainer', 'SimpleFeedfreeTrainer',
'FeedfreeTrainerBase', 'FeedfreeTrainerBase',
...@@ -369,7 +368,6 @@ def autodoc_skip_member(app, what, name, obj, skip, options): ...@@ -369,7 +368,6 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'GaussianDeform', 'GaussianDeform',
'dump_chkpt_vars', 'dump_chkpt_vars',
'VisualQA', 'VisualQA',
'ParamRestore',
'huber_loss' 'huber_loss'
]: ]:
return True return True
......
...@@ -8,14 +8,12 @@ import tensorflow as tf ...@@ -8,14 +8,12 @@ import tensorflow as tf
import six import six
from ..utils import logger from ..utils import logger
from ..utils.develop import deprecated
from .common import get_op_tensor_name from .common import get_op_tensor_name
from .varmanip import (SessionUpdate, get_savename_from_varname, from .varmanip import (SessionUpdate, get_savename_from_varname,
is_training_name, get_checkpoint_path) is_training_name, get_checkpoint_path)
__all__ = ['SessionInit', 'ChainInit', __all__ = ['SessionInit', 'ChainInit',
'SaverRestore', 'SaverRestoreRelaxed', 'SaverRestore', 'SaverRestoreRelaxed', 'DictRestore',
'ParamRestore', 'DictRestore',
'JustCurrentSession', 'get_model_loader', 'TryResumeTraining'] 'JustCurrentSession', 'get_model_loader', 'TryResumeTraining']
...@@ -191,24 +189,24 @@ class DictRestore(SessionInit): ...@@ -191,24 +189,24 @@ class DictRestore(SessionInit):
Restore variables from a dictionary. Restore variables from a dictionary.
""" """
def __init__(self, param_dict): def __init__(self, variable_dict):
""" """
Args: Args:
param_dict (dict): a dict of {name: value} variable_dict (dict): a dict of {name: value}
""" """
assert isinstance(param_dict, dict), type(param_dict) assert isinstance(variable_dict, dict), type(variable_dict)
# use varname (with :0) for consistency # use varname (with :0) for consistency
self.prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(param_dict)} self._prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(variable_dict)}
def _run_init(self, sess): def _run_init(self, sess):
variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
variable_names = set([k.name for k in variables]) variable_names = set([k.name for k in variables])
param_names = set(six.iterkeys(self.prms)) param_names = set(six.iterkeys(self._prms))
intersect = variable_names & param_names intersect = variable_names & param_names
logger.info("Params to restore: {}".format(', '.join(map(str, intersect)))) logger.info("Variables to restore from dict: {}".format(', '.join(map(str, intersect))))
mismatch = MismatchLogger('graph', 'dict') mismatch = MismatchLogger('graph', 'dict')
for k in sorted(variable_names - param_names): for k in sorted(variable_names - param_names):
...@@ -222,12 +220,7 @@ class DictRestore(SessionInit): ...@@ -222,12 +220,7 @@ class DictRestore(SessionInit):
upd = SessionUpdate(sess, [v for v in variables if v.name in intersect]) upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
logger.info("Restoring from dict ...") logger.info("Restoring from dict ...")
upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect}) upd.update({name: value for name, value in six.iteritems(self._prms) if name in intersect})
@deprecated("Use `DictRestore` instead!", "2017-09-01")
def ParamRestore(d):
    """Deprecated alias kept for backward compatibility.

    Args:
        d (dict): a dict of {name: value}, forwarded unchanged.

    Returns:
        DictRestore: a restorer built from ``d``.
    """
    # Thin shim: construct and hand back the replacement class.
    restorer = DictRestore(d)
    return restorer
class ChainInit(SessionInit): class ChainInit(SessionInit):
......
...@@ -16,7 +16,7 @@ from .multigpu import MultiGPUTrainerBase ...@@ -16,7 +16,7 @@ from .multigpu import MultiGPUTrainerBase
from .utility import override_to_local_variable from .utility import override_to_local_variable
__all__ = ['DistributedReplicatedTrainer', 'DistributedTrainerReplicated'] __all__ = ['DistributedTrainerReplicated']
class DistributedTrainerReplicated(MultiGPUTrainerBase): class DistributedTrainerReplicated(MultiGPUTrainerBase):
...@@ -336,8 +336,3 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase): ...@@ -336,8 +336,3 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
@property @property
def vs_name_for_predictor(self): def vs_name_for_predictor(self):
return "tower0" return "tower0"
def DistributedReplicatedTrainer(*args, **kwargs):
    """Deprecated alias for :class:`DistributedTrainerReplicated`.

    Emits a rename warning, then forwards all positional and keyword
    arguments unchanged to the new class.
    """
    # Warn on every use so callers migrate to the new name.
    logger.warn("DistributedReplicatedTrainer was renamed to DistributedTrainerReplicated!")
    trainer = DistributedTrainerReplicated(*args, **kwargs)
    return trainer
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment