Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ac15b641
Commit
ac15b641
authored
Jan 16, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix lint
parent
be4759be
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
13 deletions
+20
-13
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+13
-5
tensorpack/graph_builder/utils.py
tensorpack/graph_builder/utils.py
+6
-7
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+1
-1
No files found.
tensorpack/callbacks/inference_runner.py
View file @
ac15b641
...
@@ -25,7 +25,7 @@ from .base import Callback
...
@@ -25,7 +25,7 @@ from .base import Callback
from
.group
import
Callbacks
from
.group
import
Callbacks
from
.inference
import
Inferencer
from
.inference
import
Inferencer
__all__
=
[
'InferenceRunner'
,
__all__
=
[
'InferenceRunner
Base'
,
'InferenceRunner
'
,
'DataParallelInferenceRunner'
]
'DataParallelInferenceRunner'
]
...
@@ -170,7 +170,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -170,7 +170,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
"""
"""
Inference with data-parallel support on multiple GPUs.
Inference with data-parallel support on multiple GPUs.
It will build one predict tower on each GPU, and run prediction
It will build one predict tower on each GPU, and run prediction
with a large total batch.
with a large total batch in parallel on all GPUs.
It will run the remainder (when the total size of input is not a multiple of #GPU)
sequentially.
"""
"""
def
__init__
(
self
,
input
,
infs
,
gpus
):
def
__init__
(
self
,
input
,
infs
,
gpus
):
"""
"""
...
@@ -188,6 +190,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -188,6 +190,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
assert
self
.
_size
>
0
,
"Input for DataParallelInferenceRunner must have a size!"
assert
self
.
_size
>
0
,
"Input for DataParallelInferenceRunner must have a size!"
self
.
_gpus
=
gpus
self
.
_gpus
=
gpus
self
.
_hooks
=
[]
self
.
_hooks_parallel
=
[]
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
_handles
=
[]
self
.
_handles
=
[]
...
@@ -209,15 +214,18 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -209,15 +214,18 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
# e.g. hooks from StagingInput will force the consumption
# e.g. hooks from StagingInput will force the consumption
# of nr_tower datapoints in every run.
# of nr_tower datapoints in every run.
input_hooks
=
self
.
_input_callbacks
.
get_hooks
()
input_hooks
=
self
.
_input_callbacks
.
get_hooks
()
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
+
input_hooks
self
.
_hooks
.
extend
([
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
+
input_hooks
)
self
.
_hooks_parallel
=
[
self
.
_build_hook_parallel
(
inf
)
for
inf
in
self
.
infs
]
+
input_hooks
self
.
_hooks_parallel
.
extend
([
self
.
_build_hook_parallel
(
inf
)
for
inf
in
self
.
infs
]
+
input_hooks
)
for
inf
in
self
.
infs
:
for
inf
in
self
.
infs
:
inf
.
setup_graph
(
self
.
trainer
)
inf
.
setup_graph
(
self
.
trainer
)
self
.
_input_callbacks
.
setup_graph
(
self
.
trainer
)
self
.
_input_callbacks
.
setup_graph
(
self
.
trainer
)
def
register_hook
(
self
,
h
):
def
register_hook
(
self
,
h
):
raise
NotImplementedError
(
"DataParallelInferenceRunner doesn't accept extra hooks!"
)
logger
.
info
(
"[DataParallelInferenceRunner] Registering hook {} on both parallel and sequential inference."
)
self
.
_hooks
.
append
(
h
)
self
.
_hooks_parallel
.
append
(
h
)
class
InferencerToHookDataParallel
(
InferencerToHook
):
class
InferencerToHookDataParallel
(
InferencerToHook
):
def
__init__
(
self
,
inf
,
fetches
,
size
):
def
__init__
(
self
,
inf
,
fetches
,
size
):
...
...
tensorpack/graph_builder/utils.py
View file @
ac15b641
...
@@ -7,7 +7,7 @@ from contextlib import contextmanager
...
@@ -7,7 +7,7 @@ from contextlib import contextmanager
import
operator
import
operator
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..tfutils.common
import
get_tf_version_number
from
..tfutils.common
import
get_tf_version_number
__all__
=
[
'LeastLoadedDeviceSetter'
,
__all__
=
[
'LeastLoadedDeviceSetter'
,
...
@@ -45,18 +45,17 @@ def override_to_local_variable(enable=True):
...
@@ -45,18 +45,17 @@ def override_to_local_variable(enable=True):
orig_vs
=
tf
.
get_variable_scope
()
orig_vs
=
tf
.
get_variable_scope
()
if
get_tf_version_number
()
>=
1.5
:
if
get_tf_version_number
()
>=
1.5
:
with
tf
.
variable_scope
(
with
tf
.
variable_scope
(
tf
.
get_variable_scope
()
,
orig_vs
,
custom_getter
=
custom_getter
,
custom_getter
=
custom_getter
,
auxiliary_name_scope
=
False
):
auxiliary_name_scope
=
False
):
yield
yield
else
:
else
:
if
get_tf_version_number
()
>=
1.2
:
if
get_tf_version_number
()
>=
1.2
:
ns
=
tf
.
get_default_graph
()
.
get_name_scope
()
ns
=
tf
.
get_default_graph
()
.
get_name_scope
()
else
:
else
:
ns
=
tf
.
get_variable_scope
()
.
original_name_scope
ns
=
orig_vs
.
original_name_scope
with
tf
.
variable_scope
(
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
orig_vs
,
custom_getter
=
custom_getter
):
custom_getter
=
custom_getter
):
with
tf
.
name_scope
(
ns
+
'/'
):
with
tf
.
name_scope
(
ns
+
'/'
):
yield
yield
else
:
else
:
...
...
tensorpack/input_source/input_source.py
View file @
ac15b641
...
@@ -121,7 +121,7 @@ class EnqueueThread(ShareSessionThread):
...
@@ -121,7 +121,7 @@ class EnqueueThread(ShareSessionThread):
# self._size = queue.size()
# self._size = queue.size()
def
run
(
self
):
def
run
(
self
):
with
self
.
default_sess
()
as
sess
:
with
self
.
default_sess
():
try
:
try
:
self
.
reinitialize_dataflow
()
self
.
reinitialize_dataflow
()
while
True
:
while
True
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment