Commit 377fba1f authored by Yuxin Wu's avatar Yuxin Wu

clean-up some deprecation

parent 73b63247
......@@ -353,7 +353,6 @@ def process_signature(app, what, name, obj, options, signature,
def autodoc_skip_member(app, what, name, obj, skip, options):
if name in [
'DistributedReplicatedTrainer',
'SingleCostFeedfreeTrainer',
'SimpleFeedfreeTrainer',
'FeedfreeTrainerBase',
......@@ -369,7 +368,6 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'GaussianDeform',
'dump_chkpt_vars',
'VisualQA',
'ParamRestore',
'huber_loss'
]:
return True
......
......@@ -8,14 +8,12 @@ import tensorflow as tf
import six
from ..utils import logger
from ..utils.develop import deprecated
from .common import get_op_tensor_name
from .varmanip import (SessionUpdate, get_savename_from_varname,
is_training_name, get_checkpoint_path)
__all__ = ['SessionInit', 'ChainInit',
'SaverRestore', 'SaverRestoreRelaxed',
'ParamRestore', 'DictRestore',
'SaverRestore', 'SaverRestoreRelaxed', 'DictRestore',
'JustCurrentSession', 'get_model_loader', 'TryResumeTraining']
......@@ -191,24 +189,24 @@ class DictRestore(SessionInit):
Restore variables from a dictionary.
"""
def __init__(self, param_dict):
def __init__(self, variable_dict):
"""
Args:
param_dict (dict): a dict of {name: value}
variable_dict (dict): a dict of {name: value}
"""
assert isinstance(param_dict, dict), type(param_dict)
assert isinstance(variable_dict, dict), type(variable_dict)
# use varname (with :0) for consistency
self.prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(param_dict)}
self._prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(variable_dict)}
def _run_init(self, sess):
variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
variable_names = set([k.name for k in variables])
param_names = set(six.iterkeys(self.prms))
param_names = set(six.iterkeys(self._prms))
intersect = variable_names & param_names
logger.info("Params to restore: {}".format(', '.join(map(str, intersect))))
logger.info("Variables to restore from dict: {}".format(', '.join(map(str, intersect))))
mismatch = MismatchLogger('graph', 'dict')
for k in sorted(variable_names - param_names):
......@@ -222,12 +220,7 @@ class DictRestore(SessionInit):
upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
logger.info("Restoring from dict ...")
upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})
@deprecated("Use `DictRestore` instead!", "2017-09-01")
def ParamRestore(d):
return DictRestore(d)
upd.update({name: value for name, value in six.iteritems(self._prms) if name in intersect})
class ChainInit(SessionInit):
......
......@@ -16,7 +16,7 @@ from .multigpu import MultiGPUTrainerBase
from .utility import override_to_local_variable
__all__ = ['DistributedReplicatedTrainer', 'DistributedTrainerReplicated']
__all__ = ['DistributedTrainerReplicated']
class DistributedTrainerReplicated(MultiGPUTrainerBase):
......@@ -336,8 +336,3 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
@property
def vs_name_for_predictor(self):
return "tower0"
def DistributedReplicatedTrainer(*args, **kwargs):
logger.warn("DistributedReplicatedTrainer was renamed to DistributedTrainerReplicated!")
return DistributedTrainerReplicated(*args, **kwargs)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment