Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
3700a803
Commit
3700a803
authored
Aug 21, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Make FeedInput & QueueInput support dict-based dataflow (#768)
parent
209da29e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
3 deletions
+16
-3
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+16
-3
No files found.
tensorpack/input_source/input_source.py
View file @
3700a803
...
...
@@ -36,6 +36,18 @@ def _get_reset_callback(df):
return
CallbackFactory
(
setup_graph
=
lambda
_
:
df
.
reset_state
())
def
_make_feeds
(
placeholders
,
datapoint
):
assert
len
(
datapoint
)
==
len
(
placeholders
),
\
"Size of datapoint and placeholders are different: {} != {}"
.
format
(
len
(
datapoint
),
len
(
placeholders
))
if
isinstance
(
datapoint
,
(
list
,
tuple
)):
return
dict
(
zip
(
placeholders
,
datapoint
))
elif
isinstance
(
datapoint
,
dict
):
ret
=
{
p
:
datapoint
[
p
.
op
.
name
]
for
p
in
placeholders
}
return
ret
class
PlaceholderInput
(
InputSource
):
"""
Just produce placeholders as input tensors.
...
...
@@ -69,7 +81,7 @@ class FeedInput(InputSource):
def
_before_run
(
self
,
_
):
dp
=
next
(
self
.
_itr
)
assert
len
(
dp
)
==
len
(
self
.
_placeholders
),
"[FeedInput] datapoints and inputs are of different length!"
feed
=
dict
(
zip
(
self
.
_placeholders
,
dp
)
)
feed
=
_make_feeds
(
self
.
_placeholders
,
dp
)
return
tf
.
train
.
SessionRunArgs
(
fetches
=
[],
feed_dict
=
feed
)
def
_reset
(
self
):
...
...
@@ -142,7 +154,7 @@ class EnqueueThread(ShareSessionThread):
self
.
_running
.
wait
()
dp
=
next
(
self
.
_itr
)
feed
=
dict
(
zip
(
self
.
placehdrs
,
dp
)
)
feed
=
_make_feeds
(
self
.
placehdrs
,
dp
)
# _, sz = sess.run([self.op, self._sz], feed_dict=feed)
self
.
op
.
run
(
feed_dict
=
feed
)
except
(
tf
.
errors
.
CancelledError
,
tf
.
errors
.
OutOfRangeError
,
DataFlowTerminated
):
...
...
@@ -473,12 +485,13 @@ class TFDatasetInput(FeedfreeInput):
dataset, if the dataflow iterator can terminate.
Args:
df (DataFlow)
df (DataFlow)
: a dataflow which produces lists
types([tf.DType])
Returns:
(tf.data.Dataset)
"""
# TODO theoretically it can support dict
assert
isinstance
(
df
,
DataFlow
),
df
assert
isinstance
(
types
,
(
list
,
tuple
)),
types
df
=
MapData
(
df
,
lambda
dp
:
tuple
(
dp
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment