Shashank Suhas / seminar-breakout · Commits

Commit e46e6bca, authored Jul 12, 2017 by Yuxin Wu (parent d9b96535)

    make DataParallelFeedInput runnable again

Showing 3 changed files, with 47 additions and 19 deletions:

    tensorpack/callbacks/inference_runner.py   +18 / -12
    tensorpack/callbacks/summary.py             +1 /  -1
    tensorpack/train/input_source.py           +28 /  -6
tensorpack/callbacks/inference_runner.py (view file @ e46e6bca)

```diff
@@ -175,16 +175,18 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
         return InferencerToHook(inf, ret)
 
 
-# TODO completely broken now!
+# TODO some scripts to test
 class DataParallelInferenceRunner(InferenceRunnerBase):
     """
-    Broken. Don't use.
+    Inference by feeding datapoints in a data-parallel way to multiple GPUs.
+    Doesn't support remapped InputSource for now.
     """
     def __init__(self, input, infs, gpus):
         """
         Args:
             input (DataParallelFeedInput or DataFlow)
             gpus (list[int]): list of GPU id
         """
         if isinstance(input, DataFlow):
             tower_names = [TowerContext.get_predict_tower_name(k)
                            for k in range(len(gpus))]
@@ -197,8 +199,6 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
     def _setup_graph(self):
         model = self.trainer.model
         self._input_source.setup(model.get_inputs_desc())
-        assert len(self._input_source.get_callbacks()) == 0, \
-            "DataParallelInferenceRunner doesn't support any InputSource which requires callbacks!"
 
         # build graph
         def build_tower(k):
@@ -214,6 +214,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         # setup feeds and hooks
         self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs]
         self._hooks = [self._build_hook(inf) for inf in self.infs]
+        cbs = self._input_source.get_callbacks()
+        self._hooks_parallel.extend([CallbackToHook(cb) for cb in cbs])
 
     def _duplicate_names_across_towers(self, names):
         ret = []
@@ -262,15 +264,19 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         nr_tower = len(self._gpus)
         with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar:
             while total >= nr_tower:
-                feed = self._input_source.next_feed()
-                self._parallel_hooked_sess.run(fetches=[], feed_dict=feed)
+                self._parallel_hooked_sess.run(fetches=[])
                 pbar.update(nr_tower)
                 total -= nr_tower
             # take care of the rest
-            while total > 0:
-                # TODO XXX doesn't support remap
-                feed = self._input_source._next_feed(cnt=1)
-                self._hooked_sess.run(fetches=[], feed_dict=feed)
-                pbar.update(1)
-                total -= 1
+            try:
+                while total > 0:
+                    # TODO XXX doesn't support remap
+                    feed = self._input_source.next_feed(cnt=1)
+                    self._hooked_sess.run(fetches=[], feed_dict=feed)
+                    pbar.update(1)
+                    total -= 1
+            except AttributeError:
+                logger.error("[DataParallelInferenceRunner] doesn't support InputSource wrappers very well!")
+                logger.error("[DataParallelInferenceRunner] Skipping the rest of the datapoints ...")
         summary_inferencer(self.trainer, self.infs)
```
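The reworked loop splits the `total` datapoints into full data-parallel steps, fed implicitly by the InputSource's callback (hence the `run(fetches=[])` with no feed_dict), plus a single-tower remainder loop. A minimal, framework-free sketch of that arithmetic; the function name is made up for illustration:

```python
# Framework-free sketch of the loop structure in _trigger above: run full
# data-parallel steps while at least nr_tower datapoints remain, then drain
# the remainder one datapoint (one tower) at a time. Illustrative only.
def split_steps(total, nr_tower):
    parallel_steps = 0
    while total >= nr_tower:   # one parallel session.run covers all towers
        parallel_steps += 1
        total -= nr_tower
    single_steps = 0
    while total > 0:           # leftovers: one single-tower run each
        single_steps += 1
        total -= 1
    return parallel_steps, single_steps

assert split_steps(total=10, nr_tower=4) == (2, 2)   # 2*4 + 2*1 == 10
```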
tensorpack/callbacks/summary.py (view file @ e46e6bca)

```diff
@@ -27,7 +27,7 @@ class MovingAverageSummary(Callback):
     def _setup_graph(self):
         ops = tf.get_collection(self._collection)
-        logger.info("Maintain moving averages of {} ops.".format(len(ops)))
+        logger.info("Maintain moving averages of {} tensors.".format(len(ops)))
         self.ema_op = tf.group(*ops, name='summary_moving_averages')
         self._fetch = tf.train.SessionRunArgs(fetches=self.ema_op)
```
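For context on the surrounding lines: `_setup_graph` gathers every update op registered in a collection and groups them into a single op to run with each step. A hedged TF1-style sketch of that pattern, assuming a made-up collection name and variable (not tensorpack's actual ones):

```python
# Hedged TF1-style sketch of the collection-and-group pattern above.
# 'MY_EMA_OPS' and the variable are illustrative; requires TensorFlow 1.x.
import tensorflow as tf

x = tf.Variable(0.0)
ema = tf.train.ExponentialMovingAverage(decay=0.9)
tf.add_to_collection('MY_EMA_OPS', ema.apply([x]))   # register an update op

ops = tf.get_collection('MY_EMA_OPS')
ema_op = tf.group(*ops, name='summary_moving_averages')  # one fetchable op
```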
tensorpack/train/input_source.py (view file @ e46e6bca)

```diff
@@ -132,6 +132,7 @@ class FeedInput(InputSource):
         def _reset(self):
             self._ds.reset_state()
+            self._itr = self._ds.get_data()
 
     def __init__(self, ds):
         """
```
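The added line matters because resetting the dataflow does not revive an already-exhausted iterator; `_reset` has to ask for a fresh one. A small pure-Python illustration (`FakeDataFlow` is made up):

```python
# Pure-Python illustration of why _reset re-creates the iterator.
class FakeDataFlow:
    def reset_state(self):
        self._data = [1, 2, 3]
    def get_data(self):
        return iter(self._data)

ds = FakeDataFlow()
ds.reset_state()
itr = ds.get_data()
assert list(itr) == [1, 2, 3]   # iterator is now exhausted
ds.reset_state()                # resetting the dataflow alone is not enough:
assert list(itr) == []          # the old iterator stays exhausted
itr = ds.get_data()             # what the added line does
assert list(itr) == [1, 2, 3]
```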
```diff
@@ -160,11 +161,31 @@ class FeedInput(InputSource):
         return [self._cb]
 
 
-# TODO completely broken now!
 class DataParallelFeedInput(FeedInput):
     """
     Input by feeding k datapoints to k copies of placeholders located on k towers.
     """
+    class _DataParallelFeedCallback(Callback):
+        def __init__(self, ds, placeholders_per_tower):
+            self._ds = ds
+            self._itr = self._ds.get_data()
+            self._placehdrs_per_tower = placeholders_per_tower
+            self._nr_tower = len(self._placehdrs_per_tower)
+
+        def _reset(self):
+            self._ds.reset_state()
+            self._itr = self._ds.get_data()
+
+        def _before_run(self, _):
+            cnt = self._nr_tower
+            feed = {}
+            for t in range(cnt):
+                dp = next(self._itr)
+                f = dict(zip(self._placehdrs_per_tower[t], dp))
+                feed.update(f)
+            return tf.train.SessionRunArgs(fetches=[], feed_dict=feed)
+
     def __init__(self, ds, tower_names):
         super(DataParallelFeedInput, self).__init__(ds)
         self._tower_names = tower_names
```
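`_before_run` pulls one datapoint per tower and merges the per-tower `dict(zip(...))` feeds into a single feed_dict, so one session run feeds every tower at once. A framework-free sketch of that merge, with placeholders mocked as strings:

```python
# Mocked sketch of the feed built by _before_run above: one datapoint per
# tower, zipped against that tower's placeholders, merged into one dict.
placehdrs_per_tower = [['tower0/x', 'tower0/y'], ['tower1/x', 'tower1/y']]
itr = iter([[1, 2], [3, 4]])   # two datapoints, each matching [x, y]

feed = {}
for t in range(len(placehdrs_per_tower)):
    dp = next(itr)
    feed.update(dict(zip(placehdrs_per_tower[t], dp)))

assert feed == {'tower0/x': 1, 'tower0/y': 2, 'tower1/x': 3, 'tower1/y': 4}
```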
```diff
@@ -176,6 +197,7 @@ class DataParallelFeedInput(FeedInput):
             # build a list of placeholders for each tower
             self._placehdrs_per_tower.append(
                 [v.build_placeholder(prefix=tname + '/') for v in inputs])
+        self._cb = self._DataParallelFeedCallback(self._repeat_ds, self._placehdrs_per_tower)
         self.reset_state()
 
     def _get_input_tensors(self):
```
```diff
@@ -183,16 +205,16 @@ class DataParallelFeedInput(FeedInput):
         ctx = get_current_tower_context()
         return self._placehdrs_per_tower[ctx.index]
 
-    def next_feed(self, cnt=None):
+    def next_feed(self, cnt=1):
         """
         Args:
-            cnt: how many towers to feed to.
-                Defaults to the total number of towers
+            cnt: how many towers to feed to.
         """
-        if cnt is None:
-            cnt = self._nr_tower
+        cnt = int(cnt)
+        assert cnt < self._nr_tower
         feed = {}
         for t in range(cnt):
-            dp = next(self.data_producer)
+            dp = next(self._cb._itr)
             f = dict(zip(self._placehdrs_per_tower[t], dp))
             feed.update(f)
         return feed
```
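`next_feed(cnt=1)` is what the remainder loop in `DataParallelInferenceRunner._trigger` calls: it builds a feed for only the first `cnt` towers, while full batches go through the callback path instead (hence the `assert cnt < self._nr_tower`). A mocked sketch of that behavior:

```python
# Mocked sketch of next_feed(cnt=1): only the first cnt towers get a feed,
# so leftover datapoints can be drained through tower 0 one at a time.
placehdrs_per_tower = [['tower0/x'], ['tower1/x']]
itr = iter([[7], [8], [9]])

def next_feed(cnt=1):
    assert cnt < len(placehdrs_per_tower)   # full batches use the callback path
    feed = {}
    for t in range(cnt):
        feed.update(dict(zip(placehdrs_per_tower[t], next(itr))))
    return feed

assert next_feed(cnt=1) == {'tower0/x': 7}
assert next_feed(cnt=1) == {'tower0/x': 8}
```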