Commit d71184c4 authored by Yuxin Wu's avatar Yuxin Wu

fix #1312

parent 0a6dd4ae
...@@ -117,10 +117,10 @@ if __name__ == '__main__': ...@@ -117,10 +117,10 @@ if __name__ == '__main__':
data=FeedInput(dataset_train), data=FeedInput(dataset_train),
callbacks=[ callbacks=[
ModelSaver(), # save the model after every epoch ModelSaver(), # save the model after every epoch
MaxSaver('validation_accuracy'), # save the model with highest accuracy (prefix 'validation_')
InferenceRunner( # run inference(for validation) after every epoch InferenceRunner( # run inference(for validation) after every epoch
dataset_test, # the DataFlow instance used for validation dataset_test, # the DataFlow instance used for validation
ScalarStats(['cross_entropy_loss', 'accuracy'])), ScalarStats(['cross_entropy_loss', 'accuracy'])),
MaxSaver('validation_accuracy'), # save the model with highest accuracy (prefix 'validation_')
], ],
steps_per_epoch=steps_per_epoch, steps_per_epoch=steps_per_epoch,
max_epoch=100, max_epoch=100,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment