Commit 972e298a authored by Yuxin Wu

fix saver v1/v2 issues

parent b9498a1a
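Context for the fix, stated in hedged terms: with the Saver V2 format, a checkpoint written under a prefix such as model-10000 is split into an .index file plus one or more .data-* shards, so the prefix itself is no longer a file on disk, whereas a V1 checkpoint was a single file. The sketch below is not part of the commit (checkpoint_exists is a hypothetical helper); it only illustrates an existence check that accepts either format, which is the pattern the SaverRestore hunk below adopts.

    # A V1 checkpoint under prefix "model-10000" is a single file (plus .meta):
    #     model-10000    model-10000.meta
    # A V2 checkpoint under the same prefix is split into several files:
    #     model-10000.index    model-10000.data-00000-of-00001    model-10000.meta
    import os

    def checkpoint_exists(prefix):
        # accept either the V1 single file or the V2 index file
        return os.path.isfile(prefix) or os.path.isfile(prefix + '.index')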
......@@ -75,16 +75,6 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
self.path,
global_step=get_global_step(),
write_meta_graph=False)
# create a symbolic link for the latest model
latest = self.saver.last_checkpoints[-1]
basename = os.path.basename(latest)
linkname = os.path.join(os.path.dirname(latest), 'latest')
try:
os.unlink(linkname)
except OSError:
pass
os.symlink(basename, linkname)
except (OSError, IOError): # disk error sometimes.. just ignore it
logger.exception("Exception in ModelSaver.trigger_epoch!")
......
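A note on the hunk above, hedged: the removed block maintained a 'latest' symlink pointing at the newest checkpoint file, which only makes sense while a checkpoint is a single file; a V2 checkpoint spans several files, so one symlink cannot name it. The newest checkpoint can instead be located through the 'checkpoint' protocol file that Saver.save() keeps in the log directory, for example:

    import tensorflow as tf

    # Returns the newest checkpoint prefix recorded in <dir>/checkpoint,
    # e.g. './train_log/model-10000' (a prefix, not necessarily a file under V2).
    latest = tf.train.latest_checkpoint('./train_log')   # './train_log' is illustrative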
......@@ -84,7 +84,6 @@ Use _build_graph(self, input_vars) and get_current_tower_context().is_training i
CheckGradient()
]
class ModelFromMetaGraph(ModelDesc):
"""
Load the whole exact TF graph from a saved meta_graph.
......
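For reference, a minimal sketch of the TF call that loading a saved meta_graph relies on; the path is illustrative and this is not the class body itself:

    import tensorflow as tf

    # Rebuilds the graph stored in the .meta file into the current default graph
    # and returns the Saver it contains (if any), which can restore the weights:
    saver = tf.train.import_meta_graph('/path/to/model-10000.meta')
    # saver.restore(sess, '/path/to/model-10000')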
......@@ -54,27 +54,29 @@ class SaverRestore(SessionInit):
"""
def __init__(self, model_path, prefix=None):
"""
:param model_path: a model file or a ``checkpoint`` file.
:param model_path: a model name (model-xxxx) or a ``checkpoint`` file.
:param prefix: add a `prefix/` for every variable in this checkpoint
"""
assert os.path.isfile(model_path)
if os.path.basename(model_path) == model_path:
model_path = os.path.join('.', model_path) # avoid #4921
if os.path.basename(model_path) == 'checkpoint':
model_path = tf.train.get_checkpoint_state(
os.path.dirname(model_path)).model_checkpoint_path
assert os.path.isfile(model_path)
model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
# to be consistent with either v1 or v2
assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index')
self.set_path(model_path)
self.prefix = prefix
def _init(self, sess):
logger.info(
"Restoring checkpoint from {}.".format(self.path))
"Restoring checkpoint from {} ...".format(self.path))
chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
vars_map = self._get_vars_to_restore_multimap(chkpt_vars)
for dic in SaverRestore._produce_restore_dict(vars_map):
# multiple saver under same name scope would cause error:
# training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name
try:
saver = tf.train.Saver(var_list=dic, name=str(id(dic)), write_version=2)
except:
saver = tf.train.Saver(var_list=dic, name=str(id(dic)))
saver.restore(sess, self.path)
......
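Two things happen in the hunk above: a ``checkpoint`` file is now resolved through tf.train.latest_checkpoint and the existence assert accepts either the V1 single file or the V2 .index file, and the Saver is asked for write_version=2 with a fallback for TF builds whose Saver constructor does not take that argument. A hedged sketch of the fallback pattern (narrowed here to TypeError, whereas the diff uses a bare except):

    import tensorflow as tf

    def make_saver(var_list):
        try:
            # newer TF: request the V2 (.index + .data shards) checkpoint format
            return tf.train.Saver(var_list=var_list, write_version=2)
        except TypeError:
            # older TF: the keyword does not exist; fall back to the default format
            return tf.train.Saver(var_list=var_list)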
......@@ -10,7 +10,7 @@ def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
"""
:param logits: NxC
:param label: N
:returns: a float32 vector of length N with 0/1 values, 1 meaning incorrect prediction
:returns: a float32 vector of length N with 0/1 values. 1 means incorrect prediction
"""
return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)),
tf.float32, name=name)
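A usage sketch for prediction_incorrect with illustrative shapes and names; averaging the 0/1 vector gives an error rate:

    import tensorflow as tf

    logits = tf.placeholder(tf.float32, [None, 10])      # NxC
    label = tf.placeholder(tf.int32, [None])             # N
    wrong = prediction_incorrect(logits, label, topk=1)  # float32 vector of 0/1
    error_rate = tf.reduce_mean(wrong, name='error_rate')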
......@@ -95,9 +95,7 @@ def rms(x, name=None):
return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
def huber_loss(x, delta=1, name=None):
if name is None:
name = 'huber_loss'
def huber_loss(x, delta=1, name='huber_loss'):
sqrcost = tf.square(x)
abscost = tf.abs(x)
return tf.reduce_sum(
......
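The body of huber_loss is elided above at the reduce_sum; as a point of reference only (not necessarily the exact elided code), the standard elementwise Huber loss is 0.5*x^2 where |x| < delta and delta*|x| - 0.5*delta^2 elsewhere, e.g.:

    import tensorflow as tf

    def huber_loss_elementwise(x, delta=1.0):
        sqrcost = tf.square(x)
        abscost = tf.abs(x)
        # tf.select in TF of this era; tf.where on later versions
        return tf.select(abscost < delta,
                         sqrcost * 0.5,
                         abscost * delta - 0.5 * delta ** 2)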
......@@ -78,7 +78,7 @@ class SimpleTrainer(Trainer):
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
summary_moving_average())
summary_moving_average(), name='train_op')
# create an infinite data producer
self.config.dataset.reset_state()
......
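On the last hunk: giving the grouped op an explicit name makes it recoverable from the graph by name rather than only through a Python reference. A small sketch with stand-in ops:

    import tensorflow as tf

    apply_op = tf.no_op(name='apply_gradients')        # stand-in for apply_gradients
    ema_op = tf.no_op(name='summary_moving_average')   # stand-in for the EMA update
    train_op = tf.group(apply_op, ema_op, name='train_op')

    # tf.group returns an Operation, retrievable later by its name:
    assert tf.get_default_graph().get_operation_by_name('train_op') is train_op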