Commit 339f2173 authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] always save the last model

parent bedec8cd
...@@ -365,7 +365,7 @@ if __name__ == '__main__': ...@@ -365,7 +365,7 @@ if __name__ == '__main__':
model=Model(), model=Model(),
data=QueueInput(get_train_dataflow(add_mask=config.MODE_MASK)), data=QueueInput(get_train_dataflow(add_mask=config.MODE_MASK)),
callbacks=[ callbacks=[
PeriodicTrigger(ModelSaver(), every_k_epochs=5), ModelSaver(max_to_keep=10, keep_checkpoint_every_n_hours=1),
# linear warmup # linear warmup
ScheduledHyperParamSetter( ScheduledHyperParamSetter(
'learning_rate', 'learning_rate',
......
...@@ -25,7 +25,7 @@ class ModelSaver(Callback): ...@@ -25,7 +25,7 @@ class ModelSaver(Callback):
""" """
Args: Args:
max_to_keep (int): the same as in ``tf.train.Saver``. max_to_keep (int): the same as in ``tf.train.Saver``.
keep_checkpoint_every_n_hours (int): the same as in ``tf.train.Saver``. keep_checkpoint_every_n_hours (float): the same as in ``tf.train.Saver``.
checkpoint_dir (str): Defaults to ``logger.get_logger_dir()``. checkpoint_dir (str): Defaults to ``logger.get_logger_dir()``.
var_collections (str or list of str): collection of the variables (or list of collections) to save. var_collections (str or list of str): collection of the variables (or list of collections) to save.
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment