Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
00c47fa0
Commit
00c47fa0
authored
Feb 07, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add batchqueueinput
parent
c86cd15a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
105 additions
and
41 deletions
+105
-41
tensorpack/dataflow/format.py
tensorpack/dataflow/format.py
+4
-4
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+16
-20
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+85
-17
No files found.
tensorpack/dataflow/format.py
View file @
00c47fa0
...
...
@@ -67,10 +67,10 @@ class LMDBData(RNGDataFlow):
lmdb_path (str): a directory or a file.
shuffle (bool): shuffle the keys or not.
keys (list of str or str): list of str as the keys, used only when shuffle is True.
It can also be a format string e.g. `
'{:0>8d}'
` which will be
formatted with the indices from 0 to
`total_size - 1`
.
It can also be a format string e.g. `
`{:0>8d}`
` which will be
formatted with the indices from 0 to
*total_size - 1*
.
If not provided, it will then look in the database for `
__keys__
` which
If not provided, it will then look in the database for `
`__keys__`
` which
:func:`dump_dataflow_to_lmdb` used to store the list of keys.
If still not found, it will iterate over the database to find
all the keys.
...
...
@@ -177,7 +177,7 @@ def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
a :class:`LMDBDataDecoder` instance.
Example:
ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')
``ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')``
"""
cpb
=
get_caffe_pb
()
...
...
tensorpack/train/feedfree.py
View file @
00c47fa0
...
...
@@ -105,15 +105,11 @@ class SimpleFeedfreeTrainer(
# self.train_op = tf.group(*self.dequed_inputs)
class
QueueInputTrainer
(
SimpleFeedfreeTrainer
):
def
QueueInputTrainer
(
config
,
input_queue
=
None
,
predict_tower
=
None
):
"""
A trainer which automatically wraps ``config.dataflow`` by a
A
wrapper
trainer which automatically wraps ``config.dataflow`` by a
:class:`QueueInput`.
"""
def
__init__
(
self
,
config
,
input_queue
=
None
,
predict_tower
=
None
):
"""
Single tower Trainer, takes input from a queue
It is an equivalent of ``SimpleFeedfreeTrainer(config)`` with ``config.data = QueueInput(dataflow)``.
Args:
config(TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
...
...
@@ -127,4 +123,4 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
config
.
predict_tower
=
predict_tower
assert
len
(
config
.
tower
)
==
1
,
\
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
super
(
QueueInputTrainer
,
self
)
.
__init__
(
config
)
return
SimpleFeedfreeTrainer
(
config
)
tensorpack/train/input_data.py
View file @
00c47fa0
...
...
@@ -10,11 +10,13 @@ import six
from
..dataflow
import
DataFlow
,
RepeatedData
from
..tfutils.summary
import
add_moving_summary
from
..tfutils
import
get_op_tensor_name
from
..utils
import
logger
from
..callbacks.concurrency
import
StartProcOrThread
__all__
=
[
'InputData'
,
'QueueInput'
,
'FeedfreeInput'
,
'TensorInput'
,
'DummyConstantInput'
]
__all__
=
[
'InputData'
,
'FeedfreeInput'
,
'QueueInput'
,
'BatchQueueInput'
,
'TensorInput'
,
'DummyConstantInput'
]
@
six
.
add_metaclass
(
ABCMeta
)
...
...
@@ -90,9 +92,9 @@ class EnqueueThread(threading.Thread):
self
.
size_op
,
tf
.
float32
,
name
=
'input_queue_size'
))
def
run
(
self
):
try
:
self
.
dataflow
.
reset_state
()
with
self
.
sess
.
as_default
():
try
:
while
True
:
for
dp
in
self
.
dataflow
.
get_data
():
if
self
.
coord
.
should_stop
():
...
...
@@ -114,8 +116,9 @@ class EnqueueThread(threading.Thread):
class
QueueInput
(
FeedfreeInput
):
""" Input by enqueueing datapoints from a DataFlow to a TF queue, and dequeue
tensors to the graph. """
""" Enqueue datapoints from a DataFlow to a TF queue.
And the model receives dequeued tensors.
"""
def
__init__
(
self
,
ds
,
queue
=
None
):
"""
...
...
@@ -144,6 +147,7 @@ class QueueInput(FeedfreeInput):
def
_get_input_tensors
(
self
):
ret
=
self
.
queue
.
dequeue
(
name
=
'input_deque'
)
print
(
ret
)
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
ret
=
[
ret
]
assert
len
(
ret
)
==
len
(
self
.
input_placehdrs
)
...
...
@@ -158,6 +162,70 @@ class QueueInput(FeedfreeInput):
return
ret
class BatchQueueInput(FeedfreeInput):
    """ Enqueue datapoints from a DataFlow to a TF queue.
        And the model receives batches formed by concatenating
        dequeued tensors.
    """
    def __init__(self, ds, batch_size, queue=None):
        """
        Args:
            ds(DataFlow): the input DataFlow.
            batch_size(int): the batch size.
            queue (tf.QueueBase): Defaults to a FIFO queue of size 3000.
        """
        assert isinstance(ds, DataFlow), ds
        self.queue = queue
        self.ds = ds
        self.batch_size = int(batch_size)

    def size(self):
        # number of whole batches obtainable from the underlying DataFlow
        return self.ds.size() // self.batch_size

    def _setup(self, trainer):
        self.input_placehdrs = trainer.model.get_input_vars()
        assert len(self.input_placehdrs) > 0, \
            "BatchQueueInput can only be used with input placeholders!"

        # prepare placeholders without the first dimension: single datapoints
        # are enqueued, and dequeue_many() later re-forms the batch dimension
        placehdrs_nobatch = []
        for p in self.input_placehdrs:
            placehdrs_nobatch.append(tf.placeholder(
                dtype=p.dtype, shape=p.get_shape().as_list()[1:],
                name=get_op_tensor_name(p.name)[0] + '-nobatch'))

        # dequeue_many requires fully-defined shapes
        shape_err = ("Use of BatchQueueInput requires input variables to have "
                     "fully-defined shapes except for the batch dimension")
        shapes = []
        for p in placehdrs_nobatch:
            assert p.get_shape().is_fully_defined(), shape_err
            shapes.append(p.get_shape())

        if self.queue is None:
            self.queue = tf.FIFOQueue(
                3000, [x.dtype for x in self.input_placehdrs],
                shapes=shapes,
                name='input_queue')
        # a user-provided queue must also satisfy the shape requirement
        for shp in self.queue.shapes:
            assert shp.is_fully_defined(), shape_err

        self.thread = EnqueueThread(
            trainer, self.queue, self.ds, placehdrs_nobatch)
        trainer.config.callbacks.append(StartProcOrThread(self.thread))

    def _get_input_tensors(self):
        ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
        if isinstance(ret, tf.Tensor):  # only one input
            ret = [ret]
        assert len(ret) == len(self.input_placehdrs)
        # restore the static batch dimension lost by dequeue_many
        for qv, v in zip(ret, self.input_placehdrs):
            shp = v.get_shape().as_list()
            shp[0] = self.batch_size
            qv.set_shape(shp)
        return ret
class
DummyConstantInput
(
FeedfreeInput
):
""" Input some constant variables. Only for debugging performance issues """
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment