Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
c2a38de9
Commit
c2a38de9
authored
Aug 29, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Let InputSource.reset always get called. Add TFDatasetInput (#397)
parent
f2697f69
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
51 additions
and
5 deletions
+51
-5
tensorpack/graph_builder/input_source.py
tensorpack/graph_builder/input_source.py
+42
-2
tensorpack/graph_builder/input_source_base.py
tensorpack/graph_builder/input_source_base.py
+7
-2
tensorpack/train/config.py
tensorpack/train/config.py
+2
-1
No files found.
tensorpack/graph_builder/input_source.py
View file @
c2a38de9
...
@@ -26,6 +26,7 @@ __all__ = ['PlaceholderInput', 'FeedInput', 'DataParallelFeedInput',
...
@@ -26,6 +26,7 @@ __all__ = ['PlaceholderInput', 'FeedInput', 'DataParallelFeedInput',
'FeedfreeInput'
,
'FeedfreeInput'
,
'QueueInput'
,
'BatchQueueInput'
,
'QueueInput'
,
'BatchQueueInput'
,
'ZMQInput'
,
'DummyConstantInput'
,
'TensorInput'
,
'ZMQInput'
,
'DummyConstantInput'
,
'TensorInput'
,
'TFDatasetInput'
,
'StagingInputWrapper'
]
'StagingInputWrapper'
]
...
@@ -86,7 +87,6 @@ class FeedInput(InputSource):
...
@@ -86,7 +87,6 @@ class FeedInput(InputSource):
def
_setup
(
self
,
inputs
):
def
_setup
(
self
,
inputs
):
self
.
_all_placehdrs
=
[
v
.
build_placeholder
(
prefix
=
''
)
for
v
in
inputs
]
self
.
_all_placehdrs
=
[
v
.
build_placeholder
(
prefix
=
''
)
for
v
in
inputs
]
self
.
_cb
=
self
.
_FeedCallback
(
self
.
_iter_ds
,
self
.
_all_placehdrs
)
self
.
_cb
=
self
.
_FeedCallback
(
self
.
_iter_ds
,
self
.
_all_placehdrs
)
self
.
reset_state
()
def
_get_input_tensors
(
self
):
def
_get_input_tensors
(
self
):
return
self
.
_all_placehdrs
return
self
.
_all_placehdrs
...
@@ -135,7 +135,6 @@ class DataParallelFeedInput(FeedInput):
...
@@ -135,7 +135,6 @@ class DataParallelFeedInput(FeedInput):
self
.
_placehdrs_per_tower
.
append
(
self
.
_placehdrs_per_tower
.
append
(
[
v
.
build_placeholder
(
prefix
=
tname
+
'/'
)
for
v
in
inputs
])
[
v
.
build_placeholder
(
prefix
=
tname
+
'/'
)
for
v
in
inputs
])
self
.
_cb
=
self
.
_DataParallelFeedCallback
(
self
.
_iter_ds
,
self
.
_placehdrs_per_tower
)
self
.
_cb
=
self
.
_DataParallelFeedCallback
(
self
.
_iter_ds
,
self
.
_placehdrs_per_tower
)
self
.
reset_state
()
def
_get_input_tensors
(
self
):
def
_get_input_tensors
(
self
):
# return placeholders for each tower
# return placeholders for each tower
...
@@ -415,6 +414,47 @@ class ZMQInput(TensorInput):
...
@@ -415,6 +414,47 @@ class ZMQInput(TensorInput):
"ZMQInput has to be used with InputDesc!"
"ZMQInput has to be used with InputDesc!"
class TFDatasetInput(FeedfreeInput):
    """
    Use a :class:`tf.contrib.data.Dataset` instance as input.

    Note:
        In training, the dataset should be infinite (use :func:`repeat()`).
    """
    def __init__(self, dataset):
        """
        Args:
            dataset (tf.contrib.data.Dataset): the dataset whose elements
                will be used as the input tensors.
        """
        self._dataset = dataset

    def _setup(self, inputs_desc):
        """
        Validate the dataset against the model's input description, then
        build an initializable iterator over it.

        Args:
            inputs_desc (list): one description per model input; each entry
                is expected to expose ``.type``, ``.shape`` and ``.name``
                (presumably tensorpack ``InputDesc`` — confirm against caller).
        """
        self._desc = inputs_desc
        # The dataset must yield exactly one component per input description,
        # with exactly matching dtypes.
        types = self._dataset.output_types
        desc_types = tuple([k.type for k in inputs_desc])
        assert len(types) == len(desc_types), \
            "Dataset and InputDesc has different length! {} != {}".format(
                len(types), len(desc_types))
        assert types == desc_types, \
            "Types of dataset and InputDesc don't match! {} != {}".format(
                str(types), str(desc_types))
        shapes = self._dataset.output_shapes
        desc_shapes = [k.shape for k in inputs_desc]
        for idx, (s1, s2) in enumerate(zip(shapes, desc_shapes)):
            # Shapes only need to be *compatible* (not equal), so the
            # descriptions may leave dimensions unknown.
            s2 = tf.TensorShape(s2)
            assert s2.is_compatible_with(s1), \
                "InputDesc '{}' has incompatible shape with dataset! {} vs {}".format(
                    inputs_desc[idx].name, s2, s1)
        # An initializable iterator, so that _reset_state() can re-run its
        # initializer op (presumably restarting the dataset — TF semantics).
        self._iterator = self._dataset.make_initializable_iterator()
        self._init_op = self._iterator.initializer

    def _reset_state(self):
        # Runs the iterator's initializer op in the default session.
        self._init_op.run()

    def _get_input_tensors(self):
        # One tensor per dataset component, in the order checked in _setup().
        return self._iterator.get_next()
class
StagingInputWrapper
(
FeedfreeInput
):
class
StagingInputWrapper
(
FeedfreeInput
):
"""
"""
A wrapper around a feedfree input, to prefetch it in StagingArea (usually on GPUs).
A wrapper around a feedfree input, to prefetch it in StagingArea (usually on GPUs).
...
...
tensorpack/graph_builder/input_source_base.py
View file @
c2a38de9
...
@@ -9,6 +9,7 @@ import tensorflow as tf
...
@@ -9,6 +9,7 @@ import tensorflow as tf
from
..utils.argtools
import
memoized
from
..utils.argtools
import
memoized
from
._utils
import
get_sublist_by_names
,
get_tensors_inputs
from
._utils
import
get_sublist_by_names
,
get_tensors_inputs
from
..callbacks.base
import
CallbackFactory
__all__
=
[
'InputSource'
,
'remap_input_source'
]
__all__
=
[
'InputSource'
,
'remap_input_source'
]
...
@@ -56,14 +57,18 @@ class InputSource(object):
...
@@ -56,14 +57,18 @@ class InputSource(object):
Returns:
Returns:
list[Callback]: extra callbacks needed by this InputSource.
list[Callback]: extra callbacks needed by this InputSource.
"""
"""
return
self
.
_get_callbacks
()
return
[
CallbackFactory
(
before_train
=
lambda
_
:
self
.
reset_state
())]
+
self
.
_get_callbacks
()
def
_get_callbacks
(
self
):
def
_get_callbacks
(
self
):
return
[]
return
[]
def
reset_state
(
self
):
def
reset_state
(
self
):
"""
"""
Reinitialize this InputSource.
Initialize/reinitialize this InputSource.
For training, it will get called by the trainer in `before_train` callbacks.
For inference, the :class:`InferenceRunner` will call it each time it is triggered.
"""
"""
self
.
_reset_state
()
self
.
_reset_state
()
...
...
tensorpack/train/config.py
View file @
c2a38de9
...
@@ -125,7 +125,8 @@ class TrainConfig(object):
...
@@ -125,7 +125,8 @@ class TrainConfig(object):
else
:
else
:
raise
NotImplementedError
()
raise
NotImplementedError
()
except
NotImplementedError
:
except
NotImplementedError
:
logger
.
exception
(
"You must set `TrainConfig(steps_per_epoch)` if data.size() is not available."
)
logger
.
error
(
"You must set `TrainConfig(steps_per_epoch)` if data.size() is not available."
)
raise
else
:
else
:
steps_per_epoch
=
int
(
steps_per_epoch
)
steps_per_epoch
=
int
(
steps_per_epoch
)
self
.
steps_per_epoch
=
steps_per_epoch
self
.
steps_per_epoch
=
steps_per_epoch
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment