Commit b2ec42a8 authored by Yuxin Wu's avatar Yuxin Wu

asyncpredictor accepts multiple input vars

parent b9e2bd1b
......@@ -10,6 +10,7 @@ import weakref
from abc import abstractmethod, ABCMeta
from collections import defaultdict, namedtuple
import numpy as np
import six
from six.moves import queue
from ..utils.timer import *
......@@ -84,12 +85,13 @@ class SimulatorMaster(threading.Thread):
class Experience(object):
    """A transition of state, or experience.

    Holds the (state, action, reward) triple of one environment step.
    Any extra keyword arguments are attached to the instance as
    attributes, so callers can store whatever else they need
    (replacing the older single ``misc`` slot).
    """
    def __init__(self, state, action, reward, **kwargs):
        """
        :param state: the observed state
        :param action: the action taken in ``state``
        :param reward: the reward received for ``action``
        :param kwargs: whatever other attribute you want to save
        """
        self.state = state
        self.action = action
        self.reward = reward
        # dict.items() iterates fine on both Py2 and Py3 for a small
        # kwargs dict, so six.iteritems is unnecessary here.
        for k, v in kwargs.items():
            setattr(self, k, v)
def __init__(self, pipe_c2s, pipe_s2c):
super(SimulatorMaster, self).__init__()
......@@ -120,7 +122,7 @@ class SimulatorMaster(threading.Thread):
atexit.register(clean_context, [self.c2s_socket, self.s2c_socket], self.context)
def run(self):
self.clients = defaultdict(SimulatorMaster.ClientState)
self.clients = defaultdict(self.ClientState)
while True:
ident, msg = self.c2s_socket.recv_multipart()
client = self.clients[ident]
......
......@@ -130,7 +130,7 @@ class HumanHyperParamSetter(HyperParamSetter):
"""
super(HumanHyperParamSetter, self).__init__(param)
self.file_name = os.path.join(logger.LOG_DIR, file_name)
logger.info("Use {} for hyperparam {}.".format(
logger.info("Use {} to control hyperparam {}.".format(
self.file_name, self.param.readable_name))
def _get_value_to_set(self):
......
......@@ -81,29 +81,31 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
self.outqueue.put((tid, self.func(dp)))
class PredictorWorkerThread(threading.Thread):
def __init__(self, queue, pred_func, id, nr_input_var, batch_size=5):
    """Worker thread that batches prediction requests from a queue.

    :param queue: the shared input queue of (datapoint, Future) pairs
    :param pred_func: the predict function to run on a batch
    :param id: integer id of this worker (used for logging/debugging)
    :param nr_input_var: number of input variables per datapoint,
        so batches can be assembled per-component
    :param batch_size: max number of datapoints to batch together
    """
    super(PredictorWorkerThread, self).__init__()
    self.queue = queue
    self.func = pred_func
    # daemon thread: don't block interpreter exit on this worker
    self.daemon = True
    self.batch_size = batch_size
    self.nr_input_var = nr_input_var
    self.id = id
def run(self):
def fetch():
batched, futures = [], []
batched, futures = [[] for _ in range(self.nr_input_var)], []
inp, f = self.queue.get()
batched.append(inp)
for k in range(self.nr_input_var):
batched[k].append(inp[k])
futures.append(f)
if self.batch_size == 1:
return batched, futures
while True:
# fill a batch
cnt = 1
while cnt < self.batch_size:
try:
inp, f = self.queue.get_nowait()
batched.append(inp)
for k in range(self.nr_input_var):
batched[k].append(inp[k])
futures.append(f)
if len(batched) == self.batch_size:
break
cnt += 1
except queue.Empty:
break
return batched, futures
......@@ -111,7 +113,7 @@ class PredictorWorkerThread(threading.Thread):
while True:
batched, futures = fetch()
#print "batched size: ", len(batched), "queuesize: ", self.queue.qsize()
outputs = self.func([batched])
outputs = self.func(batched)
# debug, for speed testing
#if self.xxx is None:
#outputs = self.func([batched])
......@@ -135,7 +137,9 @@ class MultiThreadAsyncPredictor(object):
"""
self.input_queue = queue.Queue(maxsize=nr_thread*10)
self.threads = [
PredictorWorkerThread(self.input_queue, f, id, batch_size)
PredictorWorkerThread(
self.input_queue, f, id,
len(input_names), batch_size=batch_size)
for id, f in enumerate(
trainer.get_predict_funcs(
input_names, output_names, nr_thread))]
......@@ -148,7 +152,10 @@ class MultiThreadAsyncPredictor(object):
t.start()
def put_task(self, inputs, callback=None):
""" return a Future of output."""
"""
:params inputs: a data point (list of component) matching input_names (not batched)
:params callback: a callback to get called with the list of outputs
:returns: a Future of output."""
f = Future()
if callback is not None:
f.add_done_callback(callback)
......
......@@ -92,7 +92,7 @@ class MapGradient(GradientProcessor):
def __init__(self, func, regex='.*'):
"""
:param func: takes a tensor and returns a tensor
;param regex: used to match variables. default to match all variables.
:param regex: used to match variables. default to match all variables.
"""
self.func = func
if not regex.endswith('$'):
......
......@@ -107,6 +107,7 @@ class QueueInputTrainer(Trainer):
:param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
Defaults to a FIFO queue of size 100.
:param predict_tower: list of gpu idx to run prediction. default to be [0].
Use -1 for cpu.
"""
super(QueueInputTrainer, self).__init__(config)
self.input_vars = self.model.get_input_vars()
......@@ -136,7 +137,7 @@ class QueueInputTrainer(Trainer):
tf.get_variable_scope().reuse_variables()
for k in self.predict_tower:
logger.info("Building graph for predict towerp{}...".format(k))
with tf.device('/gpu:{}'.format(k)), \
with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
tf.name_scope('towerp{}'.format(k)):
self.model.build_graph(inputs, False)
......
......@@ -122,10 +122,10 @@ def subproc_call(cmd, timeout=None):
shell=True, timeout=timeout)
return output
except subprocess.TimeoutExpired as e:
logger.warn("Timeout in evaluation!")
logger.warn("Command timeout!")
logger.warn(e.output)
except subprocess.CalledProcessError as e:
logger.warn("Evaluation script failed: {}".format(e.returncode))
logger.warn("Command failed: {}".format(e.returncode))
logger.warn(e.output)
class OrderedContainer(object):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment