Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
fa0d4dc6
Commit
fa0d4dc6
authored
Nov 19, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add dataflow_to_dataset (#397)
parent
985236c3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
2 deletions
+25
-2
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+25
-2
No files found.
tensorpack/input_source/input_source.py
View file @
fa0d4dc6
...
...
@@ -14,7 +14,7 @@ from six.moves import range, zip
import
threading
from
.input_source_base
import
InputSource
from
..dataflow
import
DataFlow
,
RepeatedData
,
DataFlowTerminated
from
..dataflow
import
DataFlow
,
MapData
,
RepeatedData
,
DataFlowTerminated
from
..tfutils.summary
import
add_moving_summary
from
..tfutils.common
import
get_op_tensor_name
from
..tfutils.tower
import
get_current_tower_context
...
...
@@ -437,7 +437,30 @@ class TFDatasetInput(FeedfreeInput):
self
.
_init_op
.
run
()
def
_get_input_tensors
(
self
):
return
self
.
_iterator
.
get_next
()
desc_shapes
=
[
k
.
shape
for
k
in
self
.
_desc
]
ret
=
self
.
_iterator
.
get_next
()
assert
len
(
ret
)
==
len
(
desc_shapes
)
for
t
,
shp
in
zip
(
ret
,
desc_shapes
):
t
.
set_shape
(
shp
)
return
ret
@
staticmethod
def
dataflow_to_dataset
(
df
,
types
):
"""
Wrap a dataflow to tf.data.Dataset.
Will reset df.
Args:
df (DataFlow)
types([tf.DType])
"""
assert
isinstance
(
df
,
DataFlow
),
df
assert
isinstance
(
types
,
(
list
,
tuple
)),
types
df
=
MapData
(
df
,
lambda
dp
:
tuple
(
dp
))
df
.
reset_state
()
ds
=
tf
.
data
.
Dataset
.
from_generator
(
df
.
get_data
,
tuple
(
types
))
return
ds
class
StagingInput
(
FeedfreeInput
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment