Commit 13e3c39a authored by Yuxin Wu's avatar Yuxin Wu

TENSORPACK_PIPEDIR

parent 12a7b7ff
......@@ -95,7 +95,7 @@ class PrefetchProcessZMQ(multiprocessing.Process):
class PrefetchDataZMQ(ProxyDataFlow):
""" Work the same as `PrefetchData`, but faster. """
def __init__(self, ds, nr_proc=1, pipedir='.'):
def __init__(self, ds, nr_proc=1, pipedir=None):
"""
:param ds: a `DataFlow` instance.
:param nr_proc: number of processes to use. When larger than 1, order
......@@ -111,7 +111,10 @@ class PrefetchDataZMQ(ProxyDataFlow):
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL)
assert os.path.isdir(pipedir)
if pipedir is None:
pipedir = os.environ.get('TENSORPACK_PIPEDIR', '.')
assert os.path.isdir(pipedir), pipedir
self.pipename = "ipc://{}/dataflow-pipe-".format(pipedir.rstrip('/')) + str(uuid.uuid1())[:6]
self.socket.set_hwm(5) # a little bit faster than default, don't know why
self.socket.bind(self.pipename)
......
......@@ -78,7 +78,7 @@ def get_predict_func(config):
if config.input_data_mapping is None:
input_map = input_vars
else:
input_map = [input_vars[k] for k in config.input_data_mapping]
input_map = [input_vars[k] for k in config.input_data_mapping if k >= 0]
# check output_var_names against output_vars
output_vars = get_vars_by_names(output_var_names)
......
......@@ -150,7 +150,7 @@ class MultiThreadAsyncPredictor(object):
def put_task(self, inputs, callback=None):
"""
:param inputs: a data point (list of component) matching input_names (not batched)
:param callback: a callback to get called with the list of outputs
:param callback: a thread-safe callback to get called with the list of outputs
:returns: a Future of output."""
f = Future()
if callback is not None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment