Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ab2cd7e6
Commit
ab2cd7e6
authored
Nov 07, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update keras example
parent
edb1f6c3
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
37 additions
and
23 deletions
+37
-23
docs/tutorial/input-source.md
docs/tutorial/input-source.md
+4
-3
examples/mnist-keras-v2.py
examples/mnist-keras-v2.py
+1
-6
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+1
-9
tensorpack/contrib/keras.py
tensorpack/contrib/keras.py
+27
-5
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+4
-0
No files found.
docs/tutorial/input-source.md
View file @
ab2cd7e6
...
@@ -72,8 +72,9 @@ Let's take a look at what users are asking for:
...
@@ -72,8 +72,9 @@ Let's take a look at what users are asking for:
*
[
Different ways to pad your data
](
https://github.com/tensorflow/tensorflow/issues/13969
)
*
[
Different ways to pad your data
](
https://github.com/tensorflow/tensorflow/issues/13969
)
*
[
Handle none values in data
](
https://github.com/tensorflow/tensorflow/issues/13865
)
*
[
Handle none values in data
](
https://github.com/tensorflow/tensorflow/issues/13865
)
*
[
Handle dataset that's not a multiple of batch size
](
https://github.com/tensorflow/tensorflow/issues/13745
)
*
[
Handle dataset that's not a multiple of batch size
](
https://github.com/tensorflow/tensorflow/issues/13745
)
*
[
Take variable-length np array
](
https://github.com/tensorflow/tensorflow/issues/13018
)
*
[
Different levels of determinism
](
https://github.com/tensorflow/tensorflow/issues/13932
)
*
[
Different levels of determinism
](
https://github.com/tensorflow/tensorflow/issues/13932
)
*
[
Sort/skip some data
](
https://github.com/tensorflow/tensorflow/issues/14250
)
*
[
Take variable-length np array
](
https://github.com/tensorflow/tensorflow/issues/13018
)
To support these features which could've been done with 3 lines of code in Python, you need either a new TF
To support these features which could've been done with 3 lines of code in Python, you need either a new TF
API, or ask
[
Dataset.from_generator
](
https://www.tensorflow.org/versions/r1.4/api_docs/python/tf/contrib/data/Dataset#from_generator
)
API, or ask
[
Dataset.from_generator
](
https://www.tensorflow.org/versions/r1.4/api_docs/python/tf/contrib/data/Dataset#from_generator
)
...
@@ -82,8 +83,8 @@ API, or ask [Dataset.from_generator](https://www.tensorflow.org/versions/r1.4/ap
...
@@ -82,8 +83,8 @@ API, or ask [Dataset.from_generator](https://www.tensorflow.org/versions/r1.4/ap
It only makes sense to use TF to read data, if your data is originally very clean and well-formated.
It only makes sense to use TF to read data, if your data is originally very clean and well-formated.
If not, you may feel like writing a script to clean your data, but then you're almost writing a Python loader already!
If not, you may feel like writing a script to clean your data, but then you're almost writing a Python loader already!
Think about it: it's a waste of time to write a Python script to transform from raw data to
TFRecords
,
Think about it: it's a waste of time to write a Python script to transform from raw data to
clean format (e.g. TFRecords)
,
then a TF script to transform from
TFRecords
to tensors.
then a TF script to transform from
this format
to tensors.
The intermediate step (TFRecords) doesn't have to exist.
The intermediate step (TFRecords) doesn't have to exist.
You just need the right interface to connect Python to the graph directly, efficiently.
You just need the right interface to connect Python to the graph directly, efficiently.
`tensorpack.InputSource`
is such an interface.
`tensorpack.InputSource`
is such an interface.
...
...
examples/mnist-keras-v2.py
View file @
ab2cd7e6
...
@@ -57,11 +57,6 @@ if __name__ == '__main__':
...
@@ -57,11 +57,6 @@ if __name__ == '__main__':
metrics
=
[
'accuracy'
]
metrics
=
[
'accuracy'
]
)
)
M
.
fit
(
M
.
fit
(
callbacks
=
[
validation_data
=
dataset_test
,
ModelSaver
(),
InferenceRunner
(
dataset_test
,
[
ScalarStats
([
'total_loss'
,
'accuracy'
])]),
],
steps_per_epoch
=
dataset_train
.
size
(),
steps_per_epoch
=
dataset_train
.
size
(),
)
)
tensorpack/callbacks/saver.py
View file @
ab2cd7e6
...
@@ -8,7 +8,6 @@ import os
...
@@ -8,7 +8,6 @@ import os
from
.base
import
Callback
from
.base
import
Callback
from
..utils
import
logger
from
..utils
import
logger
from
..utils.develop
import
log_deprecated
from
..tfutils.common
import
get_tf_version_number
from
..tfutils.common
import
get_tf_version_number
__all__
=
[
'ModelSaver'
,
'MinSaver'
,
'MaxSaver'
]
__all__
=
[
'ModelSaver'
,
'MinSaver'
,
'MaxSaver'
]
...
@@ -22,8 +21,7 @@ class ModelSaver(Callback):
...
@@ -22,8 +21,7 @@ class ModelSaver(Callback):
def
__init__
(
self
,
max_to_keep
=
10
,
def
__init__
(
self
,
max_to_keep
=
10
,
keep_checkpoint_every_n_hours
=
0.5
,
keep_checkpoint_every_n_hours
=
0.5
,
checkpoint_dir
=
None
,
checkpoint_dir
=
None
,
var_collections
=
tf
.
GraphKeys
.
GLOBAL_VARIABLES
,
var_collections
=
tf
.
GraphKeys
.
GLOBAL_VARIABLES
):
keep_recent
=
None
,
keep_freq
=
None
):
"""
"""
Args:
Args:
max_to_keep (int): the same as in ``tf.train.Saver``.
max_to_keep (int): the same as in ``tf.train.Saver``.
...
@@ -33,12 +31,6 @@ class ModelSaver(Callback):
...
@@ -33,12 +31,6 @@ class ModelSaver(Callback):
"""
"""
self
.
_max_to_keep
=
max_to_keep
self
.
_max_to_keep
=
max_to_keep
self
.
_keep_every_n_hours
=
keep_checkpoint_every_n_hours
self
.
_keep_every_n_hours
=
keep_checkpoint_every_n_hours
if
keep_recent
is
not
None
or
keep_freq
is
not
None
:
log_deprecated
(
"ModelSaver(keep_recent=, keep_freq=)"
,
"Use max_to_keep and keep_checkpoint_every_n_hours!"
)
if
keep_recent
is
not
None
:
self
.
_max_to_keep
=
keep_recent
if
keep_freq
is
not
None
:
self
.
_keep_every_n_hours
=
keep_freq
if
not
isinstance
(
var_collections
,
list
):
if
not
isinstance
(
var_collections
,
list
):
var_collections
=
[
var_collections
]
var_collections
=
[
var_collections
]
...
...
tensorpack/contrib/keras.py
View file @
ab2cd7e6
...
@@ -9,7 +9,9 @@ import keras
...
@@ -9,7 +9,9 @@ import keras
from
..graph_builder
import
InputDesc
from
..graph_builder
import
InputDesc
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.collection
import
freeze_collection
from
..tfutils.collection
import
freeze_collection
from
..callbacks
import
Callback
,
InferenceRunner
,
CallbackToHook
from
..callbacks
import
(
Callback
,
InferenceRunner
,
CallbackToHook
,
ScalarStats
,
ModelSaver
)
from
..tfutils.summary
import
add_moving_summary
from
..tfutils.summary
import
add_moving_summary
from
..utils.gpu
import
get_nr_gpu
from
..utils.gpu
import
get_nr_gpu
from
..train
import
Trainer
,
SimpleTrainer
,
SyncMultiGPUTrainerParameterServer
from
..train
import
Trainer
,
SimpleTrainer
,
SyncMultiGPUTrainerParameterServer
...
@@ -107,6 +109,9 @@ class KerasModel(object):
...
@@ -107,6 +109,9 @@ class KerasModel(object):
"""
"""
Args:
Args:
model (keras.model.Model):
model (keras.model.Model):
input (InputSource):
trainer (Trainer): the default will check the number of available
GPUs and use them all.
"""
"""
self
.
model
=
model
self
.
model
=
model
if
trainer
is
None
:
if
trainer
is
None
:
...
@@ -117,10 +122,16 @@ class KerasModel(object):
...
@@ -117,10 +122,16 @@ class KerasModel(object):
trainer
=
SyncMultiGPUTrainerParameterServer
(
nr_gpu
)
trainer
=
SyncMultiGPUTrainerParameterServer
(
nr_gpu
)
assert
isinstance
(
trainer
,
Trainer
),
trainer
assert
isinstance
(
trainer
,
Trainer
),
trainer
self
.
trainer
=
trainer
self
.
input
=
input
self
.
input
=
input
self
.
trainer
=
trainer
def
compile
(
self
,
optimizer
,
loss
,
metrics
):
def
compile
(
self
,
optimizer
,
loss
,
metrics
):
"""
Args:
optimizer (tf.train.Optimizer):
loss, metrics: same as in `keras.model.Model.compile()`.
"""
self
.
_metrics
=
metrics
setup_keras_trainer
(
setup_keras_trainer
(
self
.
trainer
,
model
=
self
.
model
,
self
.
trainer
,
model
=
self
.
model
,
input
=
self
.
input
,
input
=
self
.
input
,
...
@@ -128,10 +139,21 @@ class KerasModel(object):
...
@@ -128,10 +139,21 @@ class KerasModel(object):
loss
=
loss
,
loss
=
loss
,
metrics
=
metrics
)
metrics
=
metrics
)
def
fit
(
self
,
**
kwargs
):
def
fit
(
self
,
validation_data
=
None
,
**
kwargs
):
"""
Args:
validation_data (DataFlow or InputSource): to be used for inference.
kwargs: same as `self.trainer.train_with_defaults`.
"""
callbacks
=
kwargs
.
pop
(
'callbacks'
,
[])
callbacks
=
kwargs
.
pop
(
'callbacks'
,
[])
callbacks
.
extend
(
self
.
get_default_callbacks
())
callbacks
.
extend
(
self
.
get_default_callbacks
())
self
.
trainer
.
train_with_defaults
(
**
kwargs
)
if
validation_data
is
not
None
:
callbacks
.
append
(
InferenceRunner
(
validation_data
,
ScalarStats
(
self
.
_metrics
+
[
'total_loss'
])))
self
.
trainer
.
train_with_defaults
(
callbacks
=
callbacks
,
**
kwargs
)
def
get_default_callbacks
(
self
):
def
get_default_callbacks
(
self
):
return
[]
return
[
ModelSaver
(
keep_checkpoint_every_n_hours
=
0.2
)
]
tensorpack/dataflow/common.py
View file @
ab2cd7e6
...
@@ -219,6 +219,10 @@ class FixedSizeData(ProxyDataFlow):
...
@@ -219,6 +219,10 @@ class FixedSizeData(ProxyDataFlow):
def
size
(
self
):
def
size
(
self
):
return
self
.
_size
return
self
.
_size
def
reset_state
(
self
):
super
(
FixedSizeData
,
self
)
.
reset_state
()
self
.
itr
=
self
.
ds
.
get_data
()
def
get_data
(
self
):
def
get_data
(
self
):
with
self
.
_guard
:
with
self
.
_guard
:
if
self
.
itr
is
None
:
if
self
.
itr
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment