Commit 00c47fa0, authored Feb 07, 2017 by Yuxin Wu
add batchqueueinput
parent c86cd15a
Showing 3 changed files with 105 additions and 41 deletions.
tensorpack/dataflow/format.py    +4  -4
tensorpack/train/feedfree.py     +16 -20
tensorpack/train/input_data.py   +85 -17
tensorpack/dataflow/format.py
@@ -67,10 +67,10 @@ class LMDBData(RNGDataFlow):
             lmdb_path (str): a directory or a file.
             shuffle (bool): shuffle the keys or not.
             keys (list of str or str): list of str as the keys, used only when shuffle is True.
-                It can also be a format string e.g. `'{:0>8d}'` which will be
-                formatted with the indices from 0 to `total_size - 1`.
-                If not provided, it will then look in the database for `__keys__` which
+                It can also be a format string e.g. ``{:0>8d}`` which will be
+                formatted with the indices from 0 to *total_size - 1*.
+                If not provided, it will then look in the database for ``__keys__`` which
                 :func:`dump_dataflow_to_lmdb` used to store the list of keys.
                 If still not found, it will iterate over the database to find
                 all the keys.
...
@@ -177,7 +177,7 @@ def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
         a :class:`LMDBDataDecoder` instance.
     Example:
-        ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')
+        ``ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')``
     """
     cpb = get_caffe_pb()
...
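(For reference, a key format string of this kind is expanded with plain Python string formatting, independent of tensorpack:)

# The key generated for index 42 with the format string from the docstring above:
'{:0>8d}'.format(42)   # -> '00000042'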
tensorpack/train/feedfree.py
@@ -105,26 +105,22 @@ class SimpleFeedfreeTrainer(
         # self.train_op = tf.group(*self.dequed_inputs)
 
-class QueueInputTrainer(SimpleFeedfreeTrainer):
+def QueueInputTrainer(config, input_queue=None, predict_tower=None):
     """
-    A trainer which automatically wraps ``config.dataflow`` by a
+    A wrapper trainer which automatically wraps ``config.dataflow`` by a
     :class:`QueueInput`.
-    """
-    def __init__(self, config, input_queue=None, predict_tower=None):
-        """
-        Single tower Trainer, takes input from a queue
-        Args:
-            config(TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
-            input_queue(tf.QueueBase): an input queue. Defaults to the
-                :class:`QueueInput` default.
-        """
-        config.data = QueueInput(config.dataflow, input_queue)
-        if predict_tower is not None:
-            logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
-                        "Use TrainConfig(predict_tower=...) instead!")
-            config.predict_tower = predict_tower
-        assert len(config.tower) == 1, \
-            "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
-        super(QueueInputTrainer, self).__init__(config)
+    It is an equivalent of ``SimpleFeedfreeTrainer(config)`` with ``config.data = QueueInput(dataflow)``.
+
+    Args:
+        config(TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
+        input_queue(tf.QueueBase): an input queue. Defaults to the
+            :class:`QueueInput` default.
+    """
+    config.data = QueueInput(config.dataflow, input_queue)
+    if predict_tower is not None:
+        logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
+                    "Use TrainConfig(predict_tower=...) instead!")
+        config.predict_tower = predict_tower
+    assert len(config.tower) == 1, \
+        "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
+    return SimpleFeedfreeTrainer(config)
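Usage note: QueueInputTrainer is now a plain function that fills in config.data and returns a SimpleFeedfreeTrainer. A minimal sketch against the API shown in this diff (the queue capacity and dtypes below are hypothetical, and `config` is assumed to be an existing TrainConfig with `dataflow` set):

import tensorflow as tf
from tensorpack.train.feedfree import QueueInputTrainer

# Optional: hand in a custom queue; by default QueueInput builds its own.
queue = tf.FIFOQueue(300, [tf.float32, tf.int32], name='custom_input_queue')  # hypothetical
trainer = QueueInputTrainer(config, input_queue=queue)  # returns a SimpleFeedfreeTrainer
trainer.train()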
tensorpack/train/input_data.py
@@ -10,11 +10,13 @@ import six
 from ..dataflow import DataFlow, RepeatedData
 from ..tfutils.summary import add_moving_summary
+from ..tfutils import get_op_tensor_name
 from ..utils import logger
 from ..callbacks.concurrency import StartProcOrThread
 
-__all__ = ['InputData', 'QueueInput', 'FeedfreeInput', 'TensorInput',
+__all__ = ['InputData', 'FeedfreeInput',
+           'QueueInput', 'BatchQueueInput', 'TensorInput',
            'DummyConstantInput']
 
 @six.add_metaclass(ABCMeta)
...
@@ -90,9 +92,9 @@ class EnqueueThread(threading.Thread):
             self.size_op, tf.float32, name='input_queue_size'))
 
     def run(self):
         self.dataflow.reset_state()
         try:
             with self.sess.as_default():
                 while True:
                     for dp in self.dataflow.get_data():
                         if self.coord.should_stop():
...
@@ -100,22 +102,23 @@ class EnqueueThread(threading.Thread):
                         feed = dict(zip(self.placehdrs, dp))
                         # print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
                         self.op.run(feed_dict=feed)
         except tf.errors.CancelledError:
             pass
         except Exception:
             logger.exception("Exception in EnqueueThread:")
         finally:
             self.coord.request_stop()
             try:
                 self.sess.run(self.close_op)
             except RuntimeError:    # session already closed
                 pass
             logger.info("Enqueue Thread Exited.")
 
 
 class QueueInput(FeedfreeInput):
-    """ Input by enqueueing datapoints from a DataFlow to a TF queue, and dequeue
-        tensors to the graph. """
+    """ Enqueue datapoints from a DataFlow to a TF queue.
+        And the model receives dequeued tensors.
+    """
 
     def __init__(self, ds, queue=None):
         """
...
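Aside: the two hunks above show the enqueue side of the pattern that EnqueueThread implements, which is the standard TF 1.x queue-feeding idiom. A self-contained sketch (toy shapes and names, not tensorpack code):

import threading
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[2], name='x')
queue = tf.FIFOQueue(50, [tf.float32], shapes=[[2]], name='toy_queue')
enqueue_op = queue.enqueue([x])
dequeued = queue.dequeue(name='input_deque')

sess = tf.Session()

def producer():
    # In tensorpack this loop iterates over a DataFlow; here we just push random data.
    for _ in range(10):
        sess.run(enqueue_op, feed_dict={x: np.random.rand(2).astype('float32')})
    sess.run(queue.close())

threading.Thread(target=producer, daemon=True).start()
for _ in range(10):
    print(sess.run(dequeued))  # the model side only ever sees dequeued tensors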
@@ -144,6 +147,7 @@ class QueueInput(FeedfreeInput):
     def _get_input_tensors(self):
         ret = self.queue.dequeue(name='input_deque')
+        print(ret)
         if isinstance(ret, tf.Tensor):  # only one input
             ret = [ret]
         assert len(ret) == len(self.input_placehdrs)
...
@@ -158,6 +162,70 @@ class QueueInput(FeedfreeInput):
         return ret
 
 
+class BatchQueueInput(FeedfreeInput):
+    """ Enqueue datapoints from a DataFlow to a TF queue.
+        And the model receives batches formed by concatenating
+        dequeued tensors.
+    """
+    def __init__(self, ds, batch_size, queue=None):
+        """
+        Args:
+            ds(DataFlow): the input DataFlow.
+            batch_size(int): the batch size.
+            queue (tf.QueueBase): Defaults to a FIFO queue of size 3000.
+        """
+        assert isinstance(ds, DataFlow), ds
+        self.queue = queue
+        self.ds = ds
+        self.batch_size = int(batch_size)
+
+    def size(self):
+        return self.ds.size() // self.batch_size
+
+    def _setup(self, trainer):
+        self.input_placehdrs = trainer.model.get_input_vars()
+        assert len(self.input_placehdrs) > 0, \
+            "QueueInput can only be used with input placeholders!"
+
+        # prepare placeholders without the first dimension
+        placehdrs_nobatch = []
+        for p in self.input_placehdrs:
+            placehdrs_nobatch.append(tf.placeholder(
+                dtype=p.dtype, shape=p.get_shape().as_list()[1:],
+                name=get_op_tensor_name(p.name)[0] + '-nobatch'))
+
+        # dequeue_many requires fully-defined shapes
+        shape_err = "Use of BatchQueueInput requires input variables to have fully-defined "
+        "shapes except for the batch dimension"
+        shapes = []
+        for p in placehdrs_nobatch:
+            assert p.get_shape().is_fully_defined(), shape_err
+            shapes.append(p.get_shape())
+
+        if self.queue is None:
+            self.queue = tf.FIFOQueue(
+                3000, [x.dtype for x in self.input_placehdrs],
+                shapes=shapes,
+                name='input_queue')
+        for shp in self.queue.shapes:
+            assert shp.is_fully_defined(), shape_err
+
+        self.thread = EnqueueThread(trainer, self.queue, self.ds, placehdrs_nobatch)
+        trainer.config.callbacks.append(StartProcOrThread(self.thread))
+
+    def _get_input_tensors(self):
+        ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
+        if isinstance(ret, tf.Tensor):  # only one input
+            ret = [ret]
+        assert len(ret) == len(self.input_placehdrs)
+        for qv, v in zip(ret, self.input_placehdrs):
+            shp = v.get_shape().as_list()
+            shp[0] = self.batch_size
+            qv.set_shape(shp)
+        return ret
+
+
 class DummyConstantInput(FeedfreeInput):
     """ Input some constant variables. Only for debugging performance issues """
...
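Usage sketch for the new class (hypothetical names; `my_dataflow` is a DataFlow yielding single examples and `config` is a TrainConfig): BatchQueueInput enqueues individual datapoints and forms the batch inside the graph with dequeue_many, so every input shape except the batch dimension must be fully defined, whereas QueueInput enqueues whatever the DataFlow yields, typically already-batched datapoints, e.g. via BatchData:

from tensorpack.dataflow import BatchData
from tensorpack.train.input_data import QueueInput, BatchQueueInput

# Batch in Python, enqueue whole batches:
config.data = QueueInput(BatchData(my_dataflow, 128))

# Or enqueue single examples and batch with dequeue_many(128) inside the graph:
config.data = BatchQueueInput(my_dataflow, batch_size=128)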