Commit ac15b641 authored by Yuxin Wu

fix lint

parent be4759be
@@ -25,7 +25,7 @@ from .base import Callback
 from .group import Callbacks
 from .inference import Inferencer
 
-__all__ = ['InferenceRunner',
+__all__ = ['InferenceRunnerBase', 'InferenceRunner',
           'DataParallelInferenceRunner']
@@ -170,7 +170,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
     """
     Inference with data-parallel support on multiple GPUs.
     It will build one predict tower on each GPU, and run prediction
-    with a large total batch.
+    with a large total batch in parallel on all GPUs.
+    It will run the remainder (when the total size of input is not a multiple of #GPU)
+    sequentially.
     """
     def __init__(self, input, infs, gpus):
         """
@@ -188,6 +190,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         assert self._size > 0, "Input for DataParallelInferenceRunner must have a size!"
         self._gpus = gpus
 
+        self._hooks = []
+        self._hooks_parallel = []
+
     def _setup_graph(self):
         self._handles = []
@@ -209,15 +214,18 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         # e.g. hooks from StagingInput will force the consumption
         # of nr_tower datapoints in every run.
         input_hooks = self._input_callbacks.get_hooks()
-        self._hooks = [self._build_hook(inf) for inf in self.infs] + input_hooks
-        self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs] + input_hooks
+        self._hooks.extend([self._build_hook(inf) for inf in self.infs] + input_hooks)
+        self._hooks_parallel.extend([self._build_hook_parallel(inf) for inf in self.infs] + input_hooks)
 
         for inf in self.infs:
             inf.setup_graph(self.trainer)
         self._input_callbacks.setup_graph(self.trainer)
 
     def register_hook(self, h):
-        raise NotImplementedError("DataParallelInferenceRunner doesn't accept extra hooks!")
+        logger.info(
+            "[DataParallelInferenceRunner] Registering hook {} on both parallel and sequential inference.".format(h))
+        self._hooks.append(h)
+        self._hooks_parallel.append(h)
 
     class InferencerToHookDataParallel(InferencerToHook):
         def __init__(self, inf, fetches, size):
...
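For orientation, here is a hedged usage sketch of the runner whose docstring is extended above. Only the `DataParallelInferenceRunner(input, infs, gpus)` signature and the parallel/remainder behaviour come from this diff; the dataflow, input source and inferencer below are illustrative placeholders, not part of this commit.

```python
# Sketch only, assuming a typical tensorpack training setup; `val_df` is a
# placeholder validation DataFlow defined elsewhere in the project.
from tensorpack.callbacks import DataParallelInferenceRunner, ScalarStats
from tensorpack.input_source import QueueInput

callbacks = [
    # One predict tower is built per listed GPU; the batch is split across
    # them in parallel, and any remainder (input size not a multiple of #GPU)
    # is run sequentially, as the amended docstring describes.
    DataParallelInferenceRunner(
        QueueInput(val_df),       # input: an InputSource with a known size
        [ScalarStats('cost')],    # infs: a list of Inferencer objects
        gpus=[0, 1]),             # gpus: GPUs to build predict towers on
]
```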
@@ -7,7 +7,7 @@ from contextlib import contextmanager
 import operator
 import tensorflow as tf
 
 from ..tfutils.common import get_tf_version_number
 
 __all__ = ['LeastLoadedDeviceSetter',
@@ -45,18 +45,17 @@ def override_to_local_variable(enable=True):
         orig_vs = tf.get_variable_scope()
         if get_tf_version_number() >= 1.5:
             with tf.variable_scope(
-                    tf.get_variable_scope(),
+                    orig_vs,
                     custom_getter=custom_getter,
                     auxiliary_name_scope=False):
                 yield
         else:
             if get_tf_version_number() >= 1.2:
                 ns = tf.get_default_graph().get_name_scope()
             else:
-                ns = tf.get_variable_scope().original_name_scope
+                ns = orig_vs.original_name_scope
             with tf.variable_scope(
-                    tf.get_variable_scope(),
-                    custom_getter=custom_getter):
+                    orig_vs, custom_getter=custom_getter):
                 with tf.name_scope(ns + '/'):
                     yield
     else:
...
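The change above reuses the `orig_vs` scope object captured at the top of the block instead of calling `tf.get_variable_scope()` again. A hedged usage sketch of the context manager follows (TF 1.x; the import path and variable names are assumptions, not taken from this commit):

```python
import tensorflow as tf
# Assumed location of the context manager edited above.
from tensorpack.graph_builder.utils import override_to_local_variable

with tf.variable_scope('tower0'):
    with override_to_local_variable():
        # The current variable scope (orig_vs) is re-entered, so the variable
        # stays under 'tower0/', while the custom getter overrides how it is
        # created (as the name suggests, making it a local variable).
        counter = tf.get_variable('counter', shape=[], dtype=tf.int64,
                                  initializer=tf.zeros_initializer())

with override_to_local_variable(enable=False):
    # With enable=False the override is presumably skipped and variables
    # are created normally.
    step = tf.get_variable('step', shape=[], dtype=tf.int64,
                           initializer=tf.zeros_initializer())
```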
@@ -121,7 +121,7 @@ class EnqueueThread(ShareSessionThread):
         # self._size = queue.size()
 
     def run(self):
-        with self.default_sess() as sess:
+        with self.default_sess():
             try:
                 self.reinitialize_dataflow()
                 while True:
...
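The last hunk is the lint fix hinted at by the commit message: the context manager's return value was bound to `sess` but never used. A minimal illustration of the pattern (flake8's F841, "local variable is assigned to but never used", is a guess at the specific warning; the commit does not name it):

```python
from contextlib import contextmanager

@contextmanager
def default_sess():        # stand-in for ShareSessionThread.default_sess()
    yield "session"

# Before: the linter flags `sess` because it is assigned but never used.
with default_sess() as sess:
    pass

# After: the context manager is entered without binding its value.
with default_sess():
    pass
```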