Commit b2ec42a8 authored by Yuxin Wu's avatar Yuxin Wu

asyncpredictor accepts multiple input vars

parent b9e2bd1b
...@@ -10,6 +10,7 @@ import weakref ...@@ -10,6 +10,7 @@ import weakref
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
from collections import defaultdict, namedtuple from collections import defaultdict, namedtuple
import numpy as np import numpy as np
import six
from six.moves import queue from six.moves import queue
from ..utils.timer import * from ..utils.timer import *
...@@ -84,12 +85,13 @@ class SimulatorMaster(threading.Thread): ...@@ -84,12 +85,13 @@ class SimulatorMaster(threading.Thread):
class Experience(object):
    """A single transition of state, or experience."""

    def __init__(self, state, action, reward, **kwargs):
        """kwargs: whatever other attribute you want to save"""
        self.state = state
        self.action = action
        self.reward = reward
        # Attach every extra keyword the caller supplied as an attribute,
        # so arbitrary per-transition metadata can ride along.
        for key, value in kwargs.items():
            setattr(self, key, value)
def __init__(self, pipe_c2s, pipe_s2c): def __init__(self, pipe_c2s, pipe_s2c):
super(SimulatorMaster, self).__init__() super(SimulatorMaster, self).__init__()
...@@ -120,7 +122,7 @@ class SimulatorMaster(threading.Thread): ...@@ -120,7 +122,7 @@ class SimulatorMaster(threading.Thread):
atexit.register(clean_context, [self.c2s_socket, self.s2c_socket], self.context) atexit.register(clean_context, [self.c2s_socket, self.s2c_socket], self.context)
def run(self): def run(self):
self.clients = defaultdict(SimulatorMaster.ClientState) self.clients = defaultdict(self.ClientState)
while True: while True:
ident, msg = self.c2s_socket.recv_multipart() ident, msg = self.c2s_socket.recv_multipart()
client = self.clients[ident] client = self.clients[ident]
......
...@@ -130,7 +130,7 @@ class HumanHyperParamSetter(HyperParamSetter): ...@@ -130,7 +130,7 @@ class HumanHyperParamSetter(HyperParamSetter):
""" """
super(HumanHyperParamSetter, self).__init__(param) super(HumanHyperParamSetter, self).__init__(param)
self.file_name = os.path.join(logger.LOG_DIR, file_name) self.file_name = os.path.join(logger.LOG_DIR, file_name)
logger.info("Use {} for hyperparam {}.".format( logger.info("Use {} to control hyperparam {}.".format(
self.file_name, self.param.readable_name)) self.file_name, self.param.readable_name))
def _get_value_to_set(self): def _get_value_to_set(self):
......
...@@ -81,29 +81,31 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker): ...@@ -81,29 +81,31 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
self.outqueue.put((tid, self.func(dp))) self.outqueue.put((tid, self.func(dp)))
class PredictorWorkerThread(threading.Thread): class PredictorWorkerThread(threading.Thread):
def __init__(self, queue, pred_func, id, batch_size=5): def __init__(self, queue, pred_func, id, nr_input_var, batch_size=5):
super(PredictorWorkerThread, self).__init__() super(PredictorWorkerThread, self).__init__()
self.queue = queue self.queue = queue
self.func = pred_func self.func = pred_func
self.daemon = True self.daemon = True
self.batch_size = batch_size self.batch_size = batch_size
self.nr_input_var = nr_input_var
self.id = id self.id = id
def run(self): def run(self):
def fetch(): def fetch():
batched, futures = [], [] batched, futures = [[] for _ in range(self.nr_input_var)], []
inp, f = self.queue.get() inp, f = self.queue.get()
batched.append(inp) for k in range(self.nr_input_var):
batched[k].append(inp[k])
futures.append(f) futures.append(f)
if self.batch_size == 1: # fill a batch
return batched, futures cnt = 1
while True: while cnt < self.batch_size:
try: try:
inp, f = self.queue.get_nowait() inp, f = self.queue.get_nowait()
batched.append(inp) for k in range(self.nr_input_var):
batched[k].append(inp[k])
futures.append(f) futures.append(f)
if len(batched) == self.batch_size: cnt += 1
break
except queue.Empty: except queue.Empty:
break break
return batched, futures return batched, futures
...@@ -111,7 +113,7 @@ class PredictorWorkerThread(threading.Thread): ...@@ -111,7 +113,7 @@ class PredictorWorkerThread(threading.Thread):
while True: while True:
batched, futures = fetch() batched, futures = fetch()
#print "batched size: ", len(batched), "queuesize: ", self.queue.qsize() #print "batched size: ", len(batched), "queuesize: ", self.queue.qsize()
outputs = self.func([batched]) outputs = self.func(batched)
# debug, for speed testing # debug, for speed testing
#if self.xxx is None: #if self.xxx is None:
#outputs = self.func([batched]) #outputs = self.func([batched])
...@@ -135,7 +137,9 @@ class MultiThreadAsyncPredictor(object): ...@@ -135,7 +137,9 @@ class MultiThreadAsyncPredictor(object):
""" """
self.input_queue = queue.Queue(maxsize=nr_thread*10) self.input_queue = queue.Queue(maxsize=nr_thread*10)
self.threads = [ self.threads = [
PredictorWorkerThread(self.input_queue, f, id, batch_size) PredictorWorkerThread(
self.input_queue, f, id,
len(input_names), batch_size=batch_size)
for id, f in enumerate( for id, f in enumerate(
trainer.get_predict_funcs( trainer.get_predict_funcs(
input_names, output_names, nr_thread))] input_names, output_names, nr_thread))]
...@@ -148,7 +152,10 @@ class MultiThreadAsyncPredictor(object): ...@@ -148,7 +152,10 @@ class MultiThreadAsyncPredictor(object):
t.start() t.start()
def put_task(self, inputs, callback=None): def put_task(self, inputs, callback=None):
""" return a Future of output.""" """
:params inputs: a data point (list of component) matching input_names (not batched)
:params callback: a callback to get called with the list of outputs
:returns: a Future of output."""
f = Future() f = Future()
if callback is not None: if callback is not None:
f.add_done_callback(callback) f.add_done_callback(callback)
......
...@@ -92,7 +92,7 @@ class MapGradient(GradientProcessor): ...@@ -92,7 +92,7 @@ class MapGradient(GradientProcessor):
def __init__(self, func, regex='.*'): def __init__(self, func, regex='.*'):
""" """
:param func: takes a tensor and returns a tensor :param func: takes a tensor and returns a tensor
;param regex: used to match variables. default to match all variables. :param regex: used to match variables. default to match all variables.
""" """
self.func = func self.func = func
if not regex.endswith('$'): if not regex.endswith('$'):
......
...@@ -107,6 +107,7 @@ class QueueInputTrainer(Trainer): ...@@ -107,6 +107,7 @@ class QueueInputTrainer(Trainer):
:param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints. :param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
Defaults to a FIFO queue of size 100. Defaults to a FIFO queue of size 100.
:param predict_tower: list of gpu idx to run prediction. default to be [0]. :param predict_tower: list of gpu idx to run prediction. default to be [0].
Use -1 for cpu.
""" """
super(QueueInputTrainer, self).__init__(config) super(QueueInputTrainer, self).__init__(config)
self.input_vars = self.model.get_input_vars() self.input_vars = self.model.get_input_vars()
...@@ -136,7 +137,7 @@ class QueueInputTrainer(Trainer): ...@@ -136,7 +137,7 @@ class QueueInputTrainer(Trainer):
tf.get_variable_scope().reuse_variables() tf.get_variable_scope().reuse_variables()
for k in self.predict_tower: for k in self.predict_tower:
logger.info("Building graph for predict towerp{}...".format(k)) logger.info("Building graph for predict towerp{}...".format(k))
with tf.device('/gpu:{}'.format(k)), \ with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
tf.name_scope('towerp{}'.format(k)): tf.name_scope('towerp{}'.format(k)):
self.model.build_graph(inputs, False) self.model.build_graph(inputs, False)
......
...@@ -122,10 +122,10 @@ def subproc_call(cmd, timeout=None): ...@@ -122,10 +122,10 @@ def subproc_call(cmd, timeout=None):
shell=True, timeout=timeout) shell=True, timeout=timeout)
return output return output
except subprocess.TimeoutExpired as e: except subprocess.TimeoutExpired as e:
logger.warn("Timeout in evaluation!") logger.warn("Command timeout!")
logger.warn(e.output) logger.warn(e.output)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
logger.warn("Evaluation script failed: {}".format(e.returncode)) logger.warn("Command failed: {}".format(e.returncode))
logger.warn(e.output) logger.warn(e.output)
class OrderedContainer(object): class OrderedContainer(object):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment