Commit f5ba2692 authored by Yuxin Wu's avatar Yuxin Wu

fix #973

parent 683e43ff
...@@ -259,10 +259,15 @@ class KerasModel(object): ...@@ -259,10 +259,15 @@ class KerasModel(object):
""" """
Args: Args:
validation_data (DataFlow or InputSource): to be used for inference. validation_data (DataFlow or InputSource): to be used for inference.
The inference callback is added as the first in the callback list.
If you need to use it in a different order, please write it in the callback list manually.
kwargs: same as `self.trainer.train_with_defaults`. kwargs: same as `self.trainer.train_with_defaults`.
""" """
callbacks = kwargs.pop('callbacks', []) callbacks = kwargs.pop('callbacks', [])
if validation_data is not None: if validation_data is not None:
callbacks.append(InferenceRunner( # There is no way to guess where users want this callback. So we have to choose one.
# MinSaver may need results from this callback,
# so we put this callback at first.
callbacks.insert(0, InferenceRunner(
validation_data, ScalarStats(self._stats_to_inference))) validation_data, ScalarStats(self._stats_to_inference)))
self.trainer.train_with_defaults(callbacks=callbacks, **kwargs) self.trainer.train_with_defaults(callbacks=callbacks, **kwargs)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment