Commit ac15b641 authored by Yuxin Wu

fix lint

parent be4759be
@@ -25,7 +25,7 @@ from .base import Callback
 from .group import Callbacks
 from .inference import Inferencer

-__all__ = ['InferenceRunner',
+__all__ = ['InferenceRunnerBase', 'InferenceRunner',
            'DataParallelInferenceRunner']
@@ -170,7 +170,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
     """
     Inference with data-parallel support on multiple GPUs.
     It will build one predict tower on each GPU, and run prediction
-    with a large total batch.
+    with a large total batch in parallel on all GPUs.
+    It will run the remainder (when the total size of input is not a multiple of #GPU)
+    sequentially.
     """
     def __init__(self, input, infs, gpus):
         """
@@ -188,6 +190,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         assert self._size > 0, "Input for DataParallelInferenceRunner must have a size!"
         self._gpus = gpus
+        self._hooks = []
+        self._hooks_parallel = []

     def _setup_graph(self):
         self._handles = []
@@ -209,15 +214,18 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         # e.g. hooks from StagingInput will force the consumption
         # of nr_tower datapoints in every run.
         input_hooks = self._input_callbacks.get_hooks()
-        self._hooks = [self._build_hook(inf) for inf in self.infs] + input_hooks
-        self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs] + input_hooks
+        self._hooks.extend([self._build_hook(inf) for inf in self.infs] + input_hooks)
+        self._hooks_parallel.extend([self._build_hook_parallel(inf) for inf in self.infs] + input_hooks)

         for inf in self.infs:
             inf.setup_graph(self.trainer)
         self._input_callbacks.setup_graph(self.trainer)

     def register_hook(self, h):
-        raise NotImplementedError("DataParallelInferenceRunner doesn't accept extra hooks!")
+        logger.info(
+            "[DataParallelInferenceRunner] Registering hook {} on both parallel and sequential inference.".format(h))
+        self._hooks.append(h)
+        self._hooks_parallel.append(h)

     class InferencerToHookDataParallel(InferencerToHook):
         def __init__(self, inf, fetches, size):
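Since register_hook now accepts extra hooks, the same hook object ends up in both the parallel and the sequential hook lists. A minimal sketch; the hook class and the runner variable are hypothetical, but any tf.train.SessionRunHook should work:

    import tensorflow as tf

    class CountIterationsHook(tf.train.SessionRunHook):
        # hypothetical hook that counts how many inference steps were run
        def __init__(self):
            self.steps = 0

        def after_run(self, run_context, run_values):
            self.steps += 1

    # `runner` is a DataParallelInferenceRunner as sketched earlier.
    runner.register_hook(CountIterationsHook())   # previously raised NotImplementedError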
@@ -45,7 +45,7 @@ def override_to_local_variable(enable=True):
         orig_vs = tf.get_variable_scope()
         if get_tf_version_number() >= 1.5:
             with tf.variable_scope(
-                    tf.get_variable_scope(),
+                    orig_vs,
                     custom_getter=custom_getter,
                     auxiliary_name_scope=False):
                 yield
@@ -53,10 +53,9 @@ def override_to_local_variable(enable=True):
             if get_tf_version_number() >= 1.2:
                 ns = tf.get_default_graph().get_name_scope()
             else:
-                ns = tf.get_variable_scope().original_name_scope
+                ns = orig_vs.original_name_scope
             with tf.variable_scope(
-                    tf.get_variable_scope(),
-                    custom_getter=custom_getter):
+                    orig_vs, custom_getter=custom_getter):
                 with tf.name_scope(ns + '/'):
                     yield
     else:
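Both hunks above reuse the orig_vs captured at the top of the function instead of calling tf.get_variable_scope() again, so each branch re-enters the scope that was active when the context manager started. A condensed sketch of that pattern, assuming TF >= 1.5 and an illustrative custom getter; it is not the library's exact implementation:

    import tensorflow as tf
    from contextlib import contextmanager

    @contextmanager
    def reuse_original_scope():
        def custom_getter(getter, name, *args, **kwargs):
            # illustrative: route new variables into the local (non-saved) collection
            kwargs['collections'] = [tf.GraphKeys.LOCAL_VARIABLES]
            return getter(name, *args, **kwargs)

        orig_vs = tf.get_variable_scope()        # capture the scope once, up front
        with tf.variable_scope(
                orig_vs,                         # re-enter the same VariableScope object
                custom_getter=custom_getter,
                auxiliary_name_scope=False):     # keyword available in TF >= 1.5
            yield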
@@ -121,7 +121,7 @@ class EnqueueThread(ShareSessionThread):
         # self._size = queue.size()

     def run(self):
-        with self.default_sess() as sess:
+        with self.default_sess():
             try:
                 self.reinitialize_dataflow()
                 while True: