Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ab2cd7e6
Commit
ab2cd7e6
authored
Nov 07, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update keras example
parent
edb1f6c3
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
37 additions
and
23 deletions
+37
-23
docs/tutorial/input-source.md
docs/tutorial/input-source.md
+4
-3
examples/mnist-keras-v2.py
examples/mnist-keras-v2.py
+1
-6
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+1
-9
tensorpack/contrib/keras.py
tensorpack/contrib/keras.py
+27
-5
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+4
-0
No files found.
docs/tutorial/input-source.md
View file @
ab2cd7e6
...
...
@@ -72,8 +72,9 @@ Let's take a look at what users are asking for:
*
[
Different ways to pad your data
](
https://github.com/tensorflow/tensorflow/issues/13969
)
*
[
Handle none values in data
](
https://github.com/tensorflow/tensorflow/issues/13865
)
*
[
Handle dataset that's not a multiple of batch size
](
https://github.com/tensorflow/tensorflow/issues/13745
)
*
[
Take variable-length np array
](
https://github.com/tensorflow/tensorflow/issues/13018
)
*
[
Different levels of determinism
](
https://github.com/tensorflow/tensorflow/issues/13932
)
*
[
Sort/skip some data
](
https://github.com/tensorflow/tensorflow/issues/14250
)
*
[
Take variable-length np array
](
https://github.com/tensorflow/tensorflow/issues/13018
)
To support these features which could've been done with 3 lines of code in Python, you need either a new TF
API, or ask
[
Dataset.from_generator
](
https://www.tensorflow.org/versions/r1.4/api_docs/python/tf/contrib/data/Dataset#from_generator
)
...
...
@@ -82,8 +83,8 @@ API, or ask [Dataset.from_generator](https://www.tensorflow.org/versions/r1.4/ap
It only makes sense to use TF to read data if your data is originally very clean and well-formatted.
If not, you may feel like writing a script to clean your data, but then you're almost writing a Python loader already!
Think about it: it's a waste of time to write a Python script to transform from raw data to
TFRecords
,
then a TF script to transform from
TFRecords
to tensors.
Think about it: it's a waste of time to write a Python script to transform from raw data to
clean format (e.g. TFRecords)
,
then a TF script to transform from
this format
to tensors.
The intermediate step (TFRecords) doesn't have to exist.
You just need the right interface to connect Python to the graph directly, efficiently.
`tensorpack.InputSource`
is such an interface.
...
...
examples/mnist-keras-v2.py
View file @
ab2cd7e6
...
...
@@ -57,11 +57,6 @@ if __name__ == '__main__':
metrics
=
[
'accuracy'
]
)
M
.
fit
(
callbacks
=
[
ModelSaver
(),
InferenceRunner
(
dataset_test
,
[
ScalarStats
([
'total_loss'
,
'accuracy'
])]),
],
validation_data
=
dataset_test
,
steps_per_epoch
=
dataset_train
.
size
(),
)
tensorpack/callbacks/saver.py
View file @
ab2cd7e6
...
...
@@ -8,7 +8,6 @@ import os
from
.base
import
Callback
from
..utils
import
logger
from
..utils.develop
import
log_deprecated
from
..tfutils.common
import
get_tf_version_number
__all__
=
[
'ModelSaver'
,
'MinSaver'
,
'MaxSaver'
]
...
...
@@ -22,8 +21,7 @@ class ModelSaver(Callback):
def
__init__
(
self
,
max_to_keep
=
10
,
keep_checkpoint_every_n_hours
=
0.5
,
checkpoint_dir
=
None
,
var_collections
=
tf
.
GraphKeys
.
GLOBAL_VARIABLES
,
keep_recent
=
None
,
keep_freq
=
None
):
var_collections
=
tf
.
GraphKeys
.
GLOBAL_VARIABLES
):
"""
Args:
max_to_keep (int): the same as in ``tf.train.Saver``.
...
...
@@ -33,12 +31,6 @@ class ModelSaver(Callback):
"""
self
.
_max_to_keep
=
max_to_keep
self
.
_keep_every_n_hours
=
keep_checkpoint_every_n_hours
if
keep_recent
is
not
None
or
keep_freq
is
not
None
:
log_deprecated
(
"ModelSaver(keep_recent=, keep_freq=)"
,
"Use max_to_keep and keep_checkpoint_every_n_hours!"
)
if
keep_recent
is
not
None
:
self
.
_max_to_keep
=
keep_recent
if
keep_freq
is
not
None
:
self
.
_keep_every_n_hours
=
keep_freq
if
not
isinstance
(
var_collections
,
list
):
var_collections
=
[
var_collections
]
...
...
tensorpack/contrib/keras.py
View file @
ab2cd7e6
...
...
@@ -9,7 +9,9 @@ import keras
from
..graph_builder
import
InputDesc
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.collection
import
freeze_collection
from
..callbacks
import
Callback
,
InferenceRunner
,
CallbackToHook
from
..callbacks
import
(
Callback
,
InferenceRunner
,
CallbackToHook
,
ScalarStats
,
ModelSaver
)
from
..tfutils.summary
import
add_moving_summary
from
..utils.gpu
import
get_nr_gpu
from
..train
import
Trainer
,
SimpleTrainer
,
SyncMultiGPUTrainerParameterServer
...
...
@@ -107,6 +109,9 @@ class KerasModel(object):
"""
Args:
model (keras.model.Model):
input (InputSource):
trainer (Trainer): the default will check the number of available
GPUs and use them all.
"""
self
.
model
=
model
if
trainer
is
None
:
...
...
@@ -117,10 +122,16 @@ class KerasModel(object):
trainer
=
SyncMultiGPUTrainerParameterServer
(
nr_gpu
)
assert
isinstance
(
trainer
,
Trainer
),
trainer
self
.
trainer
=
trainer
self
.
input
=
input
self
.
trainer
=
trainer
def
compile
(
self
,
optimizer
,
loss
,
metrics
):
"""
Args:
optimizer (tf.train.Optimizer):
loss, metrics: same as in `keras.model.Model.compile()`.
"""
self
.
_metrics
=
metrics
setup_keras_trainer
(
self
.
trainer
,
model
=
self
.
model
,
input
=
self
.
input
,
...
...
@@ -128,10 +139,21 @@ class KerasModel(object):
loss
=
loss
,
metrics
=
metrics
)
def
fit
(
self
,
**
kwargs
):
def
fit
(
self
,
validation_data
=
None
,
**
kwargs
):
"""
Args:
validation_data (DataFlow or InputSource): to be used for inference.
kwargs: same as `self.trainer.train_with_defaults`.
"""
callbacks
=
kwargs
.
pop
(
'callbacks'
,
[])
callbacks
.
extend
(
self
.
get_default_callbacks
())
self
.
trainer
.
train_with_defaults
(
**
kwargs
)
if
validation_data
is
not
None
:
callbacks
.
append
(
InferenceRunner
(
validation_data
,
ScalarStats
(
self
.
_metrics
+
[
'total_loss'
])))
self
.
trainer
.
train_with_defaults
(
callbacks
=
callbacks
,
**
kwargs
)
def
get_default_callbacks
(
self
):
return
[]
return
[
ModelSaver
(
keep_checkpoint_every_n_hours
=
0.2
)
]
tensorpack/dataflow/common.py
View file @
ab2cd7e6
...
...
@@ -219,6 +219,10 @@ class FixedSizeData(ProxyDataFlow):
def
size
(
self
):
return
self
.
_size
def
reset_state
(
self
):
super
(
FixedSizeData
,
self
)
.
reset_state
()
self
.
itr
=
self
.
ds
.
get_data
()
def
get_data
(
self
):
with
self
.
_guard
:
if
self
.
itr
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment