Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
7a7295ee
Commit
7a7295ee
authored
Jun 08, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
PrintData when reset & support names in QueueInput
parent
1175aade
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
30 additions
and
38 deletions
+30
-38
tensorpack/dataflow/base.py
tensorpack/dataflow/base.py
+3
-0
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+6
-26
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+3
-3
tensorpack/train/input_source.py
tensorpack/train/input_source.py
+18
-9
No files found.
tensorpack/dataflow/base.py
View file @
7a7295ee
...
...
@@ -71,3 +71,6 @@ class ProxyDataFlow(DataFlow):
def
size
(
self
):
return
self
.
ds
.
size
()
def
get_data
(
self
):
return
self
.
ds
.
get_data
()
tensorpack/dataflow/common.py
View file @
7a7295ee
...
...
@@ -5,6 +5,7 @@
from
__future__
import
division
import
numpy
as
np
from
copy
import
copy
import
itertools
from
termcolor
import
colored
from
collections
import
deque
,
defaultdict
from
six.moves
import
range
,
map
...
...
@@ -623,7 +624,6 @@ class PrintData(ProxyDataFlow):
super
(
PrintData
,
self
)
.
__init__
(
ds
)
self
.
num
=
num
self
.
label
=
label
self
.
print_info
()
def
_analyze_input_data
(
self
,
el
,
k
,
depth
=
1
):
"""
...
...
@@ -670,30 +670,8 @@ class PrintData(ProxyDataFlow):
"""
Dump gathered debugging information to stdout.
"""
def
cutoff
(
gen
,
num
=
1
):
"""
Stop a generator after n iterations.
Args:
gen (PyGenObject): arbitrary generator
num (int, optional): number of maximal iterations
Yields:
element from generator object
"""
c
=
0
for
el
in
gen
:
yield
el
c
+=
1
if
c
==
num
:
break
ds
=
self
.
ds
ds
.
reset_state
()
msg
=
[
""
]
for
i
,
dummy
in
enumerate
(
cutoff
(
ds
.
get_data
(),
self
.
num
)):
for
i
,
dummy
in
enumerate
(
itertools
.
islice
(
self
.
ds
.
get_data
(),
self
.
num
)):
if
isinstance
(
dummy
,
list
):
msg
.
append
(
"datapoint
%
i<
%
i with
%
i components consists of"
%
(
i
,
self
.
num
,
len
(
dummy
)))
for
k
,
entry
in
enumerate
(
dummy
):
...
...
@@ -701,7 +679,9 @@ class PrintData(ProxyDataFlow):
label
=
""
if
self
.
label
is
""
else
" ("
+
self
.
label
+
")"
logger
.
info
(
colored
(
"DataFlow Info
%
s:"
%
label
,
'cyan'
)
+
'
\n
'
.
join
(
msg
))
# reset again after print
self
.
ds
.
reset_state
()
def
get_data
(
self
):
return
self
.
ds
.
get_data
()
def
reset_state
(
self
):
super
(
PrintData
,
self
)
.
reset_state
()
self
.
print_info
()
tensorpack/train/feedfree.py
View file @
7a7295ee
...
...
@@ -119,10 +119,10 @@ def QueueInputTrainer(config, input_queue=None):
input_queue (tf.QueueBase): an input queue. Defaults to the
:class:`QueueInput` default.
"""
if
config
.
dataflow
is
not
None
:
config
.
data
=
QueueInput
(
config
.
dataflow
,
input_queue
)
else
:
if
config
.
data
is
not
None
:
assert
isinstance
(
config
.
data
,
QueueInput
),
config
.
data
else
:
config
.
data
=
QueueInput
(
config
.
dataflow
,
input_queue
)
# debug
# from tensorpack.train.input_source import StagingInputWrapper, DummyConstantInput
...
...
tensorpack/train/input_source.py
View file @
7a7295ee
...
...
@@ -161,7 +161,7 @@ class FeedfreeInput(InputSource):
e.g. by queue or other operations. """
def
reset_state
(
self
):
# TODO
cannot
reset
# TODO
no state to
reset
pass
def
next_feed
(
self
):
...
...
@@ -212,17 +212,19 @@ class QueueInput(FeedfreeInput):
And the model receives dequeued tensors.
"""
def
__init__
(
self
,
ds
,
queue
=
None
):
def
__init__
(
self
,
ds
,
queue
=
None
,
names
=
None
):
"""
Args:
ds(DataFlow): the input DataFlow.
queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
should match the corresponding InputDesc of the model.
Defaults to a FIFO queue of size 50.
names(list[str]): list of input names corresponding to the dataflow.
"""
assert
isinstance
(
ds
,
DataFlow
),
ds
self
.
queue
=
queue
self
.
ds
=
ds
self
.
_names
=
names
def
size
(
self
):
return
self
.
ds
.
size
()
...
...
@@ -231,13 +233,17 @@ class QueueInput(FeedfreeInput):
def
setup
(
self
,
model
):
logger
.
info
(
"Setting up the queue for CPU prefetching ..."
)
self
.
input_placehdrs
=
model
.
get_reused_placehdrs
()
assert
len
(
self
.
input_placehdrs
)
>
0
,
\
"QueueInput has to be used with some InputDesc!"
if
self
.
_names
is
None
:
self
.
_queue_feedpoint
=
self
.
input_placehdrs
else
:
self
.
_queue_feedpoint
=
get_placeholders_by_names
(
self
.
input_placehdrs
,
self
.
_names
)
assert
len
(
self
.
_queue_feedpoint
)
>
0
,
\
"QueueInput has to be used with some inputs!"
if
self
.
queue
is
None
:
self
.
queue
=
tf
.
FIFOQueue
(
50
,
[
x
.
dtype
for
x
in
self
.
input_placehdrs
],
50
,
[
x
.
dtype
for
x
in
self
.
_queue_feedpoint
],
name
=
'input_queue'
)
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
self
.
input_placehdrs
)
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
self
.
_queue_feedpoint
)
def
setup_training
(
self
,
trainer
):
super
(
QueueInput
,
self
)
.
setup_training
(
trainer
)
...
...
@@ -250,10 +256,13 @@ class QueueInput(FeedfreeInput):
ret
=
self
.
queue
.
dequeue
(
name
=
'input_deque'
)
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
ret
=
[
ret
]
assert
len
(
ret
)
==
len
(
self
.
input_placehdrs
)
for
qv
,
v
in
zip
(
ret
,
self
.
input_placehdrs
):
assert
len
(
ret
)
==
len
(
self
.
_queue_feedpoint
)
for
qv
,
v
in
zip
(
ret
,
self
.
_queue_feedpoint
):
qv
.
set_shape
(
v
.
get_shape
())
if
self
.
_names
is
None
:
return
ret
else
:
return
get_tensors_inputs
(
self
.
input_placehdrs
,
ret
,
self
.
_names
)
class
BatchQueueInput
(
FeedfreeInput
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment