Commit ab2cd7e6 authored by Yuxin Wu

update keras example

parent edb1f6c3
@@ -72,8 +72,9 @@ Let's take a look at what users are asking for:
 * [Different ways to pad your data](https://github.com/tensorflow/tensorflow/issues/13969)
 * [Handle none values in data](https://github.com/tensorflow/tensorflow/issues/13865)
 * [Handle dataset that's not a multiple of batch size](https://github.com/tensorflow/tensorflow/issues/13745)
-* [Take variable-length np array](https://github.com/tensorflow/tensorflow/issues/13018)
 * [Different levels of determinism](https://github.com/tensorflow/tensorflow/issues/13932)
+* [Sort/skip some data](https://github.com/tensorflow/tensorflow/issues/14250)
+* [Take variable-length np array](https://github.com/tensorflow/tensorflow/issues/13018)
 
 To support these features, which could've been done with 3 lines of code in Python, you need either a new TF
 API, or ask [Dataset.from_generator](https://www.tensorflow.org/versions/r1.4/api_docs/python/tf/contrib/data/Dataset#from_generator)
@@ -82,8 +83,8 @@ API, or ask [Dataset.from_generator](https://www.tensorflow.org/versions/r1.4/api_docs/python/tf/contrib/data/Dataset#from_generator)
 It only makes sense to use TF to read data if your data is originally very clean and well-formatted.
 If not, you may feel like writing a script to clean your data, but then you're almost writing a Python loader already!
 
-Think about it: it's a waste of time to write a Python script to transform from raw data to TFRecords,
-then a TF script to transform from TFRecords to tensors.
+Think about it: it's a waste of time to write a Python script to transform from raw data to a clean format (e.g. TFRecords),
+then a TF script to transform from this format to tensors.
 The intermediate step (TFRecords) doesn't have to exist.
 You just need the right interface to connect Python to the graph directly, efficiently.
 `tensorpack.InputSource` is such an interface.
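To make the argument concrete, here is a minimal sketch of such a Python-side loader, assuming the `DataFlow` interface tensorpack exposed at the time (`load_and_clean` is a hypothetical helper standing in for whatever parsing the raw data needs):

```python
# Hedged sketch, not part of this commit: a pure-Python loader in
# tensorpack's DataFlow style. All cleaning logic stays in ordinary
# Python; there is no on-disk intermediate format.
from tensorpack.dataflow import DataFlow

class MyRawData(DataFlow):
    def __init__(self, filenames):
        self.filenames = filenames

    def size(self):
        return len(self.filenames)

    def get_data(self):
        for fname in self.filenames:
            # load_and_clean is hypothetical: padding, skipping,
            # sorting, and variable-length arrays are plain Python here.
            image, label = load_and_clean(fname)
            yield [image, label]
```

Wrapping it in an `InputSource` such as `QueueInput(MyRawData(files))` then feeds datapoints to the graph directly, which is exactly the shortcut the paragraph above argues for.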
......
@@ -57,11 +57,6 @@ if __name__ == '__main__':
         metrics=['accuracy']
     )
     M.fit(
-        callbacks=[
-            ModelSaver(),
-            InferenceRunner(
-                dataset_test,
-                [ScalarStats(['total_loss', 'accuracy'])]),
-        ],
+        validation_data=dataset_test,
         steps_per_epoch=dataset_train.size(),
     )
@@ -8,7 +8,6 @@ import os
 from .base import Callback
 from ..utils import logger
-from ..utils.develop import log_deprecated
 from ..tfutils.common import get_tf_version_number
 
 __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
@@ -22,8 +21,7 @@ class ModelSaver(Callback):
     def __init__(self, max_to_keep=10,
                  keep_checkpoint_every_n_hours=0.5,
                  checkpoint_dir=None,
-                 var_collections=tf.GraphKeys.GLOBAL_VARIABLES,
-                 keep_recent=None, keep_freq=None):
+                 var_collections=tf.GraphKeys.GLOBAL_VARIABLES):
         """
         Args:
             max_to_keep (int): the same as in ``tf.train.Saver``.
@@ -33,12 +31,6 @@ class ModelSaver(Callback):
         """
         self._max_to_keep = max_to_keep
         self._keep_every_n_hours = keep_checkpoint_every_n_hours
 
-        if keep_recent is not None or keep_freq is not None:
-            log_deprecated("ModelSaver(keep_recent=, keep_freq=)", "Use max_to_keep and keep_checkpoint_every_n_hours!")
-            if keep_recent is not None:
-                self._max_to_keep = keep_recent
-            if keep_freq is not None:
-                self._keep_every_n_hours = keep_freq
         if not isinstance(var_collections, list):
             var_collections = [var_collections]
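With the deprecated aliases gone, checkpoint retention is configured only through the `tf.train.Saver`-style arguments. A hedged example of the surviving interface, using the defaults visible in the signature above:

```python
from tensorpack.callbacks import ModelSaver

# Same effect as the removed keep_recent=10, keep_freq=0.5:
ModelSaver(max_to_keep=10, keep_checkpoint_every_n_hours=0.5)
```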
......
@@ -9,7 +9,9 @@ import keras
 from ..graph_builder import InputDesc
 from ..tfutils.tower import get_current_tower_context
 from ..tfutils.collection import freeze_collection
-from ..callbacks import Callback, InferenceRunner, CallbackToHook
+from ..callbacks import (
+    Callback, InferenceRunner, CallbackToHook,
+    ScalarStats, ModelSaver)
 from ..tfutils.summary import add_moving_summary
 from ..utils.gpu import get_nr_gpu
 from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
@@ -107,6 +109,9 @@
         """
         Args:
             model (keras.model.Model):
+            input (InputSource):
+            trainer (Trainer): the default will check the number of available
+                GPUs and use them all.
         """
         self.model = model
         if trainer is None:
@@ -117,10 +122,16 @@
             trainer = SyncMultiGPUTrainerParameterServer(nr_gpu)
         assert isinstance(trainer, Trainer), trainer
-        self.trainer = trainer
+        self.input = input
+        self.trainer = trainer
 
     def compile(self, optimizer, loss, metrics):
+        """
+        Args:
+            optimizer (tf.train.Optimizer):
+            loss, metrics: same as in `keras.model.Model.compile()`.
+        """
+        self._metrics = metrics
         setup_keras_trainer(
             self.trainer, model=self.model,
             input=self.input,
@@ -128,10 +139,21 @@
             loss=loss,
             metrics=metrics)
 
-    def fit(self, **kwargs):
+    def fit(self, validation_data=None, **kwargs):
+        """
+        Args:
+            validation_data (DataFlow or InputSource): to be used for inference.
+            kwargs: same as `self.trainer.train_with_defaults`.
+        """
         callbacks = kwargs.pop('callbacks', [])
         callbacks.extend(self.get_default_callbacks())
-        self.trainer.train_with_defaults(**kwargs)
+        if validation_data is not None:
+            callbacks.append(
+                InferenceRunner(
+                    validation_data, ScalarStats(self._metrics + ['total_loss'])))
+        self.trainer.train_with_defaults(callbacks=callbacks, **kwargs)
 
     def get_default_callbacks(self):
-        return []
+        return [
+            ModelSaver(keep_checkpoint_every_n_hours=0.2)
+        ]
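Put together, the Keras example collapses to the following usage, a hedged sketch mirroring the mnist-keras.py hunk above (the import paths and optimizer choice are assumptions; `model`, `dataset_train`, and `dataset_test` are stand-ins from the example):

```python
import tensorflow as tf
from tensorpack import QueueInput          # top-level export; path assumed
from tensorpack.contrib.keras import KerasModel

M = KerasModel(model, input=QueueInput(dataset_train))
M.compile(
    optimizer=tf.train.AdamOptimizer(1e-3),  # any tf.train.Optimizer
    loss='categorical_crossentropy',
    metrics=['accuracy'])
M.fit(
    # expands into InferenceRunner(dataset_test,
    # ScalarStats(['accuracy', 'total_loss'])), per the fit() hunk above
    validation_data=dataset_test,
    steps_per_epoch=dataset_train.size())
```

A `ModelSaver(keep_checkpoint_every_n_hours=0.2)` is attached automatically through `get_default_callbacks()`, so checkpointing no longer has to be requested explicitly as in the old example.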
@@ -219,6 +219,10 @@ class FixedSizeData(ProxyDataFlow):
     def size(self):
         return self._size
 
+    def reset_state(self):
+        super(FixedSizeData, self).reset_state()
+        self.itr = self.ds.get_data()
+
     def get_data(self):
         with self._guard:
             if self.itr is None:
......
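The new `reset_state` matters because `FixedSizeData` caches an iterator over the wrapped flow; resetting now rebuilds that iterator instead of resuming a stale one. A hedged usage sketch (`my_infinite_flow` is a stand-in for any DataFlow):

```python
# FixedSizeData caps a (possibly infinite) DataFlow at a fixed number
# of datapoints per epoch; reset_state() now also resets the wrapped
# flow and re-creates the cached iterator.
from tensorpack.dataflow import FixedSizeData

df = FixedSizeData(my_infinite_flow, 1000)  # my_infinite_flow: any DataFlow
df.reset_state()
for datapoint in df.get_data():             # yields 1000 datapoints
    pass
```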