Commit 544516ae authored by Yuxin Wu's avatar Yuxin Wu

inception v3

parent f9bca2c4
This diff is collapsed.
...@@ -10,5 +10,5 @@ Only allow examples with reproducible and meaningful performancce. ...@@ -10,5 +10,5 @@ Only allow examples with reproducible and meaningful performancce.
+ [char-rnn for fun](char-rnn) + [char-rnn for fun](char-rnn)
+ [DisturbLabel, because I don't believe the paper](DisturbLabel) + [DisturbLabel, because I don't believe the paper](DisturbLabel)
+ [DoReFa-Net, binary / low-bitwidth CNN on ImageNet](DoReFa-Net) + [DoReFa-Net, binary / low-bitwidth CNN on ImageNet](DoReFa-Net)
+ [GoogleNet-InceptionV1 with 71% accuracy](Inception) + [GoogleNet-InceptionV1 with 71% accuracy and InceptionV3 with 73.5% accuracy](Inception)
+ [ResNet for Cifar10 with similar accuracy, and for SVHN with state-of-the-art accuracy](ResNet) + [ResNet for Cifar10 with similar accuracy, and for SVHN with state-of-the-art accuracy](ResNet)
...@@ -129,16 +129,16 @@ class SaverRestore(SessionInit): ...@@ -129,16 +129,16 @@ class SaverRestore(SessionInit):
chkpt_vars_used.add(name) chkpt_vars_used.add(name)
#vars_available.remove(name) #vars_available.remove(name)
else: else:
logger.warn("Param {} not found in checkpoint! Will not restore.".format(v.op.name)) logger.warn("Variable {} not found in checkpoint!".format(v.op.name))
if len(chkpt_vars_used) < len(vars_available): if len(chkpt_vars_used) < len(vars_available):
unused = vars_available - chkpt_vars_used unused = vars_available - chkpt_vars_used
for name in unused: for name in unused:
logger.warn("Param {} in checkpoint doesn't appear in the graph!".format(name)) logger.warn("Variable {} in checkpoint doesn't exist in the graph!".format(name))
return var_dict return var_dict
class ParamRestore(SessionInit): class ParamRestore(SessionInit):
""" """
Restore trainable variables from a dictionary. Restore variables from a dictionary.
""" """
def __init__(self, param_dict): def __init__(self, param_dict):
""" """
...@@ -157,11 +157,12 @@ class ParamRestore(SessionInit): ...@@ -157,11 +157,12 @@ class ParamRestore(SessionInit):
logger.info("Params to restore: {}".format( logger.info("Params to restore: {}".format(
', '.join(map(str, intersect)))) ', '.join(map(str, intersect))))
for k in variable_names - param_names: for k in variable_names - param_names:
logger.warn("Variable {} in the graph won't be restored!".format(k)) logger.warn("Variable {} in the graph not getting restored!".format(k))
for k in param_names - variable_names: for k in param_names - variable_names:
logger.warn("Param {} not found in this graph!".format(k)) logger.warn("Variable {} in the dict not found in this graph!".format(k))
upd = SessionUpdate(sess, [v for v in variables if v.name in intersect]) upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
logger.info("Restoring from param dict ...") logger.info("Restoring from dict ...")
upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect}) upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})
...@@ -190,6 +191,6 @@ def dump_session_params(path): ...@@ -190,6 +191,6 @@ def dump_session_params(path):
for v in var: for v in var:
name = v.name.replace(":0", "") name = v.name.replace(":0", "")
result[name] = v.eval() result[name] = v.eval()
logger.info("Params to save to {}:".format(path)) logger.info("Variables to save to {}:".format(path))
logger.info(str(result.keys())) logger.info(str(result.keys()))
np.save(path, result) np.save(path, result)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment