Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
9d0b28a0
Commit
9d0b28a0
authored
May 04, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add staging input (#140). Didn't see improvement
parent
dabebf69
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
89 additions
and
1 deletion
+89
-1
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+4
-0
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+85
-1
No files found.
tensorpack/train/feedfree.py
View file @
9d0b28a0
...
...
@@ -117,6 +117,10 @@ def QueueInputTrainer(config, input_queue=None, predict_tower=None):
else
:
assert
isinstance
(
config
.
data
,
QueueInput
),
config
.
data
# from tensorpack.train.input_data import QueueInput, FeedfreeInput, StagingInputWrapper, DummyConstantInput
# config.data = StagingInputWrapper(config.data, ['/gpu:0'])
# config.data = DummyConstantInput([[64,224,224,3], [64]])
if
predict_tower
is
not
None
:
log_deprecated
(
"Argument `predict_tower` in trainer"
,
"Use TrainConfig(predict_tower=...) instead!"
)
config
.
predict_tower
=
predict_tower
...
...
tensorpack/train/input_data.py
View file @
9d0b28a0
...
...
@@ -4,20 +4,26 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
tensorflow
as
tf
from
tensorflow.contrib.staging
import
StagingArea
from
itertools
import
chain
from
abc
import
ABCMeta
,
abstractmethod
import
six
from
six.moves
import
range
from
..dataflow
import
DataFlow
,
RepeatedData
from
..tfutils.summary
import
add_moving_summary
from
..tfutils
import
get_op_tensor_name
from
..utils
import
logger
from
..utils.argtools
import
memoized
from
..utils.concurrency
import
ShareSessionThread
from
..callbacks.concurrency
import
StartProcOrThread
from
..callbacks.base
import
Callback
__all__
=
[
'InputData'
,
'FeedfreeInput'
,
'QueueInput'
,
'BatchQueueInput'
,
'ZMQInput'
,
'DummyConstantInput'
,
'TensorInput'
]
'DummyConstantInput'
,
'TensorInput'
,
'StagingInputWrapper'
]
@
six
.
add_metaclass
(
ABCMeta
)
...
...
@@ -160,6 +166,7 @@ class QueueInput(FeedfreeInput):
def
get_input_tensors
(
self
):
ret
=
self
.
queue
.
dequeue
(
name
=
'input_deque'
)
#ret[0]= tf.Print(ret[0], [tf.reduce_mean(ret[0])], "asdf")
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
ret
=
[
ret
]
assert
len
(
ret
)
==
len
(
self
.
input_placehdrs
)
...
...
@@ -306,3 +313,80 @@ class ZMQInput(FeedfreeInput):
for
qv
,
v
in
zip
(
ret
,
self
.
input_placehdrs
):
qv
.
set_shape
(
v
.
get_shape
())
return
ret
class StagingInputWrapper(FeedfreeInput):
    """
    Wrap a :class:`FeedfreeInput` so that its tensors go through one
    ``StagingArea`` per device before they are handed to the model.

    For every entry of ``devices`` a ``StagingArea`` is created under
    ``tf.device(device)``, the wrapped input's tensors are ``put`` into it,
    and the tensors returned by its ``get`` are what
    :meth:`get_input_tensors` gives back.
    """

    class StagingCallback(Callback):
        """
        Callback that pre-fills the staging areas before training and then
        runs the stage and unstage ops along with every training step.
        """

        def __init__(self, stage_op, unstage_op, nr_stage):
            """
            Args:
                stage_op: op to run to put data into the staging area(s).
                unstage_op: op to run to take data out of the staging area(s).
                nr_stage (int): how many times ``stage_op`` is run before
                    training starts.
            """
            self.nr_stage = nr_stage
            self.stage_op = stage_op
            self.unstage_op = unstage_op
            # Fetch both ops in every step. If only `stage_op` were fetched,
            # a staging area whose output is not consumed in a step would
            # receive one more element per step and never release any,
            # growing until it runs out of memory.
            self.fetches = tf.train.SessionRunArgs(
                fetches=[stage_op, unstage_op])

        def _before_train(self):
            # Pre-fill the staging area so that the first `get` has data.
            for _ in range(self.nr_stage):
                self.stage_op.run()

        def _before_run(self, ctx):
            return self.fetches

    def __init__(self, input, devices, nr_stage=5):
        """
        Args:
            input (FeedfreeInput): the input to wrap.
            devices (list): one entry per staging area; each entry is passed
                to ``tf.device`` when the area is created.
            nr_stage (int): number of elements put into each staging area
                before training starts. Defaults to 5, the value that was
                previously hard-coded.
        """
        assert isinstance(input, FeedfreeInput), input
        self._input = input
        self._devices = devices
        self._nr_stage = nr_stage

        self._areas = []
        self._stage_ops = []
        self._unstage_ops = []
        # Index of the staging area whose outputs the next call to
        # get_input_tensors() will return.
        self._cnt_unstage = 0

    def setup(self, model):
        """Set up the wrapped input with ``model``, then build the staging areas."""
        self._input.setup(model)
        self.setup_staging_areas()

    def setup_training(self, trainer):
        """
        Run the training setup of the base class and of the wrapped input,
        then register a :class:`StagingCallback` on ``trainer``.
        """
        super(StagingInputWrapper, self).setup_training(trainer)
        self._input.setup_training(trainer)

        trainer.register_callback(
            StagingInputWrapper.StagingCallback(
                self.get_stage_op(), self.get_unstage_op(), self._nr_stage))

    def setup_staging_areas(self):
        """
        Create one ``StagingArea`` per device and record its put op and its
        output tensors.
        """
        for device in self._devices:
            # Each device pulls its own set of tensors from the wrapped input.
            inputs = self._input.get_input_tensors()
            dtypes = [x.dtype for x in inputs]
            with tf.device(device):
                stage = StagingArea(dtypes, shapes=None)
            self._stage_ops.append(stage.put(inputs))
            self._areas.append(stage)
            outputs = stage.get()
            # The area is created with shapes=None, so copy the static shapes
            # of the inputs onto the outputs.
            for vin, vout in zip(inputs, outputs):
                vout.set_shape(vin.get_shape())
            self._unstage_ops.append(outputs)

    def size(self):
        """Return the size reported by the wrapped input."""
        return self._input.size()

    def get_input_tensors(self):
        """
        Return the output tensors of the next staging area.

        Successive calls walk through the staging areas in the order of
        ``devices``; calling this more times than there are staging areas
        fails the first assertion below.
        """
        assert self._cnt_unstage < len(self._areas)
        assert len(self._areas) == len(self._devices)
        ret = self._unstage_ops[self._cnt_unstage]
        self._cnt_unstage += 1
        return ret

    @staticmethod
    def get_staging_name(idx):
        """Return the name string ``'StagingArea<idx>'``."""
        return 'StagingArea{}'.format(idx)

    @memoized
    def get_stage_op(self):
        """Return a single op grouping the put ops of all staging areas."""
        return tf.group(*self._stage_ops)

    @memoized
    def get_unstage_op(self):
        """Return a single op grouping the output tensors of all staging areas."""
        all_outputs = list(chain.from_iterable(self._unstage_ops))
        return tf.group(*all_outputs)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment