Commit 544516ae authored by Yuxin Wu's avatar Yuxin Wu

inception v3

parent f9bca2c4
This diff is collapsed.
......@@ -10,5 +10,5 @@ Only allow examples with reproducible and meaningful performancce.
+ [char-rnn for fun](char-rnn)
+ [DisturbLabel, because I don't believe the paper](DisturbLabel)
+ [DoReFa-Net, binary / low-bitwidth CNN on ImageNet](DoReFa-Net)
+ [GoogleNet-InceptionV1 with 71% accuracy](Inception)
+ [GoogleNet-InceptionV1 with 71% accuracy and InceptionV3 with 73.5% accuracy](Inception)
+ [ResNet for Cifar10 with similar accuracy, and for SVHN with state-of-the-art accuracy](ResNet)
......@@ -129,16 +129,16 @@ class SaverRestore(SessionInit):
chkpt_vars_used.add(name)
#vars_available.remove(name)
else:
logger.warn("Param {} not found in checkpoint! Will not restore.".format(v.op.name))
logger.warn("Variable {} not found in checkpoint!".format(v.op.name))
if len(chkpt_vars_used) < len(vars_available):
unused = vars_available - chkpt_vars_used
for name in unused:
logger.warn("Param {} in checkpoint doesn't appear in the graph!".format(name))
logger.warn("Variable {} in checkpoint doesn't exist in the graph!".format(name))
return var_dict
class ParamRestore(SessionInit):
"""
Restore trainable variables from a dictionary.
Restore variables from a dictionary.
"""
def __init__(self, param_dict):
"""
......@@ -157,11 +157,12 @@ class ParamRestore(SessionInit):
logger.info("Params to restore: {}".format(
', '.join(map(str, intersect))))
for k in variable_names - param_names:
logger.warn("Variable {} in the graph won't be restored!".format(k))
logger.warn("Variable {} in the graph not getting restored!".format(k))
for k in param_names - variable_names:
logger.warn("Param {} not found in this graph!".format(k))
logger.warn("Variable {} in the dict not found in this graph!".format(k))
upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
logger.info("Restoring from param dict ...")
logger.info("Restoring from dict ...")
upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})
......@@ -190,6 +191,6 @@ def dump_session_params(path):
for v in var:
name = v.name.replace(":0", "")
result[name] = v.eval()
logger.info("Params to save to {}:".format(path))
logger.info("Variables to save to {}:".format(path))
logger.info(str(result.keys()))
np.save(path, result)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment