Commit 972e298a authored by Yuxin Wu's avatar Yuxin Wu

fix saver v1/v2 issues

parent b9498a1a
...@@ -75,16 +75,6 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name)) ...@@ -75,16 +75,6 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
self.path, self.path,
global_step=get_global_step(), global_step=get_global_step(),
write_meta_graph=False) write_meta_graph=False)
# create a symbolic link for the latest model
latest = self.saver.last_checkpoints[-1]
basename = os.path.basename(latest)
linkname = os.path.join(os.path.dirname(latest), 'latest')
try:
os.unlink(linkname)
except OSError:
pass
os.symlink(basename, linkname)
except (OSError, IOError): # disk error sometimes.. just ignore it except (OSError, IOError): # disk error sometimes.. just ignore it
logger.exception("Exception in ModelSaver.trigger_epoch!") logger.exception("Exception in ModelSaver.trigger_epoch!")
......
...@@ -84,7 +84,6 @@ Use _build_graph(self, input_vars) and get_current_tower_context().is_training i ...@@ -84,7 +84,6 @@ Use _build_graph(self, input_vars) and get_current_tower_context().is_training i
CheckGradient() CheckGradient()
] ]
class ModelFromMetaGraph(ModelDesc): class ModelFromMetaGraph(ModelDesc):
""" """
Load the whole exact TF graph from a saved meta_graph. Load the whole exact TF graph from a saved meta_graph.
......
...@@ -54,28 +54,30 @@ class SaverRestore(SessionInit): ...@@ -54,28 +54,30 @@ class SaverRestore(SessionInit):
""" """
def __init__(self, model_path, prefix=None): def __init__(self, model_path, prefix=None):
""" """
:param model_path: a model file or a ``checkpoint`` file. :param model_path: a model name (model-xxxx) or a ``checkpoint`` file.
:param prefix: add a `prefix/` for every variable in this checkpoint :param prefix: add a `prefix/` for every variable in this checkpoint
""" """
assert os.path.isfile(model_path)
if os.path.basename(model_path) == model_path: if os.path.basename(model_path) == model_path:
model_path = os.path.join('.', model_path) # avoid #4921 model_path = os.path.join('.', model_path) # avoid #4921
if os.path.basename(model_path) == 'checkpoint': if os.path.basename(model_path) == 'checkpoint':
model_path = tf.train.get_checkpoint_state( model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
os.path.dirname(model_path)).model_checkpoint_path # to be consistent with either v1 or v2
assert os.path.isfile(model_path) assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index')
self.set_path(model_path) self.set_path(model_path)
self.prefix = prefix self.prefix = prefix
def _init(self, sess): def _init(self, sess):
logger.info( logger.info(
"Restoring checkpoint from {}.".format(self.path)) "Restoring checkpoint from {} ...".format(self.path))
chkpt_vars = SaverRestore._read_checkpoint_vars(self.path) chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
vars_map = self._get_vars_to_restore_multimap(chkpt_vars) vars_map = self._get_vars_to_restore_multimap(chkpt_vars)
for dic in SaverRestore._produce_restore_dict(vars_map): for dic in SaverRestore._produce_restore_dict(vars_map):
# multiple saver under same name scope would cause error: # multiple saver under same name scope would cause error:
# training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name # training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name
saver = tf.train.Saver(var_list=dic, name=str(id(dic))) try:
saver = tf.train.Saver(var_list=dic, name=str(id(dic)), write_version=2)
except:
saver = tf.train.Saver(var_list=dic, name=str(id(dic)))
saver.restore(sess, self.path) saver.restore(sess, self.path)
def set_path(self, model_path): def set_path(self, model_path):
......
...@@ -10,7 +10,7 @@ def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'): ...@@ -10,7 +10,7 @@ def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
""" """
:param logits: NxC :param logits: NxC
:param label: N :param label: N
:returns: a float32 vector of length N with 0/1 values, 1 meaning incorrect prediction :returns: a float32 vector of length N with 0/1 values. 1 means incorrect prediction
""" """
return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)), return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)),
tf.float32, name=name) tf.float32, name=name)
...@@ -95,9 +95,7 @@ def rms(x, name=None): ...@@ -95,9 +95,7 @@ def rms(x, name=None):
return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name) return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name) return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
def huber_loss(x, delta=1, name=None): def huber_loss(x, delta=1, name='huber_loss'):
if name is None:
name = 'huber_loss'
sqrcost = tf.square(x) sqrcost = tf.square(x)
abscost = tf.abs(x) abscost = tf.abs(x)
return tf.reduce_sum( return tf.reduce_sum(
......
...@@ -78,7 +78,7 @@ class SimpleTrainer(Trainer): ...@@ -78,7 +78,7 @@ class SimpleTrainer(Trainer):
self.train_op = tf.group( self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()), self.config.optimizer.apply_gradients(grads, get_global_step_var()),
summary_moving_average()) summary_moving_average(), name='train_op')
# create an infinite data producer # create an infinite data producer
self.config.dataset.reset_state() self.config.dataset.reset_state()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment