Commit ab2cd7e6 authored by Yuxin Wu's avatar Yuxin Wu

update keras example

parent edb1f6c3
...@@ -72,8 +72,9 @@ Let's take a look at what users are asking for: ...@@ -72,8 +72,9 @@ Let's take a look at what users are asking for:
* [Different ways to pad your data](https://github.com/tensorflow/tensorflow/issues/13969) * [Different ways to pad your data](https://github.com/tensorflow/tensorflow/issues/13969)
* [Handle none values in data](https://github.com/tensorflow/tensorflow/issues/13865) * [Handle none values in data](https://github.com/tensorflow/tensorflow/issues/13865)
* [Handle dataset that's not a multiple of batch size](https://github.com/tensorflow/tensorflow/issues/13745) * [Handle dataset that's not a multiple of batch size](https://github.com/tensorflow/tensorflow/issues/13745)
* [Take variable-length np array](https://github.com/tensorflow/tensorflow/issues/13018)
* [Different levels of determinism](https://github.com/tensorflow/tensorflow/issues/13932) * [Different levels of determinism](https://github.com/tensorflow/tensorflow/issues/13932)
* [Sort/skip some data](https://github.com/tensorflow/tensorflow/issues/14250)
* [Take variable-length np array](https://github.com/tensorflow/tensorflow/issues/13018)
To support these features which could've been done with 3 lines of code in Python, you need either a new TF To support these features which could've been done with 3 lines of code in Python, you need either a new TF
API, or ask [Dataset.from_generator](https://www.tensorflow.org/versions/r1.4/api_docs/python/tf/contrib/data/Dataset#from_generator) API, or ask [Dataset.from_generator](https://www.tensorflow.org/versions/r1.4/api_docs/python/tf/contrib/data/Dataset#from_generator)
...@@ -82,8 +83,8 @@ API, or ask [Dataset.from_generator](https://www.tensorflow.org/versions/r1.4/ap ...@@ -82,8 +83,8 @@ API, or ask [Dataset.from_generator](https://www.tensorflow.org/versions/r1.4/ap
It only makes sense to use TF to read data, if your data is originally very clean and well-formatted. It only makes sense to use TF to read data, if your data is originally very clean and well-formatted.
If not, you may feel like writing a script to clean your data, but then you're almost writing a Python loader already! If not, you may feel like writing a script to clean your data, but then you're almost writing a Python loader already!
Think about it: it's a waste of time to write a Python script to transform from raw data to TFRecords, Think about it: it's a waste of time to write a Python script to transform from raw data to a clean format (e.g. TFRecords),
then a TF script to transform from TFRecords to tensors. then a TF script to transform from this format to tensors.
The intermediate step (TFRecords) doesn't have to exist. The intermediate step (TFRecords) doesn't have to exist.
You just need the right interface to connect Python to the graph directly, efficiently. You just need the right interface to connect Python to the graph directly, efficiently.
`tensorpack.InputSource` is such an interface. `tensorpack.InputSource` is such an interface.
......
...@@ -57,11 +57,6 @@ if __name__ == '__main__': ...@@ -57,11 +57,6 @@ if __name__ == '__main__':
metrics=['accuracy'] metrics=['accuracy']
) )
M.fit( M.fit(
callbacks=[ validation_data=dataset_test,
ModelSaver(),
InferenceRunner(
dataset_test,
[ScalarStats(['total_loss', 'accuracy'])]),
],
steps_per_epoch=dataset_train.size(), steps_per_epoch=dataset_train.size(),
) )
...@@ -8,7 +8,6 @@ import os ...@@ -8,7 +8,6 @@ import os
from .base import Callback from .base import Callback
from ..utils import logger from ..utils import logger
from ..utils.develop import log_deprecated
from ..tfutils.common import get_tf_version_number from ..tfutils.common import get_tf_version_number
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver'] __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
...@@ -22,8 +21,7 @@ class ModelSaver(Callback): ...@@ -22,8 +21,7 @@ class ModelSaver(Callback):
def __init__(self, max_to_keep=10, def __init__(self, max_to_keep=10,
keep_checkpoint_every_n_hours=0.5, keep_checkpoint_every_n_hours=0.5,
checkpoint_dir=None, checkpoint_dir=None,
var_collections=tf.GraphKeys.GLOBAL_VARIABLES, var_collections=tf.GraphKeys.GLOBAL_VARIABLES):
keep_recent=None, keep_freq=None):
""" """
Args: Args:
max_to_keep (int): the same as in ``tf.train.Saver``. max_to_keep (int): the same as in ``tf.train.Saver``.
...@@ -33,12 +31,6 @@ class ModelSaver(Callback): ...@@ -33,12 +31,6 @@ class ModelSaver(Callback):
""" """
self._max_to_keep = max_to_keep self._max_to_keep = max_to_keep
self._keep_every_n_hours = keep_checkpoint_every_n_hours self._keep_every_n_hours = keep_checkpoint_every_n_hours
if keep_recent is not None or keep_freq is not None:
log_deprecated("ModelSaver(keep_recent=, keep_freq=)", "Use max_to_keep and keep_checkpoint_every_n_hours!")
if keep_recent is not None:
self._max_to_keep = keep_recent
if keep_freq is not None:
self._keep_every_n_hours = keep_freq
if not isinstance(var_collections, list): if not isinstance(var_collections, list):
var_collections = [var_collections] var_collections = [var_collections]
......
...@@ -9,7 +9,9 @@ import keras ...@@ -9,7 +9,9 @@ import keras
from ..graph_builder import InputDesc from ..graph_builder import InputDesc
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
from ..tfutils.collection import freeze_collection from ..tfutils.collection import freeze_collection
from ..callbacks import Callback, InferenceRunner, CallbackToHook from ..callbacks import (
Callback, InferenceRunner, CallbackToHook,
ScalarStats, ModelSaver)
from ..tfutils.summary import add_moving_summary from ..tfutils.summary import add_moving_summary
from ..utils.gpu import get_nr_gpu from ..utils.gpu import get_nr_gpu
from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
...@@ -107,6 +109,9 @@ class KerasModel(object): ...@@ -107,6 +109,9 @@ class KerasModel(object):
""" """
Args: Args:
model (keras.model.Model): model (keras.model.Model):
input (InputSource):
trainer (Trainer): the default will check the number of available
GPUs and use them all.
""" """
self.model = model self.model = model
if trainer is None: if trainer is None:
...@@ -117,10 +122,16 @@ class KerasModel(object): ...@@ -117,10 +122,16 @@ class KerasModel(object):
trainer = SyncMultiGPUTrainerParameterServer(nr_gpu) trainer = SyncMultiGPUTrainerParameterServer(nr_gpu)
assert isinstance(trainer, Trainer), trainer assert isinstance(trainer, Trainer), trainer
self.trainer = trainer
self.input = input self.input = input
self.trainer = trainer
def compile(self, optimizer, loss, metrics): def compile(self, optimizer, loss, metrics):
"""
Args:
optimizer (tf.train.Optimizer):
loss, metrics: same as in `keras.model.Model.compile()`.
"""
self._metrics = metrics
setup_keras_trainer( setup_keras_trainer(
self.trainer, model=self.model, self.trainer, model=self.model,
input=self.input, input=self.input,
...@@ -128,10 +139,21 @@ class KerasModel(object): ...@@ -128,10 +139,21 @@ class KerasModel(object):
loss=loss, loss=loss,
metrics=metrics) metrics=metrics)
def fit(self, **kwargs): def fit(self, validation_data=None, **kwargs):
"""
Args:
validation_data (DataFlow or InputSource): to be used for inference.
kwargs: same as `self.trainer.train_with_defaults`.
"""
callbacks = kwargs.pop('callbacks', []) callbacks = kwargs.pop('callbacks', [])
callbacks.extend(self.get_default_callbacks()) callbacks.extend(self.get_default_callbacks())
self.trainer.train_with_defaults(**kwargs) if validation_data is not None:
callbacks.append(
InferenceRunner(
validation_data, ScalarStats(self._metrics + ['total_loss'])))
self.trainer.train_with_defaults(callbacks=callbacks, **kwargs)
def get_default_callbacks(self): def get_default_callbacks(self):
return [] return [
ModelSaver(keep_checkpoint_every_n_hours=0.2)
]
...@@ -219,6 +219,10 @@ class FixedSizeData(ProxyDataFlow): ...@@ -219,6 +219,10 @@ class FixedSizeData(ProxyDataFlow):
def size(self): def size(self):
return self._size return self._size
def reset_state(self):
super(FixedSizeData, self).reset_state()
self.itr = self.ds.get_data()
def get_data(self): def get_data(self):
with self._guard: with self._guard:
if self.itr is None: if self.itr is None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment