Commit 126fd26a authored by Yuxin Wu

[WIP] queuetrainer + multigputrainer, predict tower

parent 254af98b
@@ -9,6 +9,7 @@ import threading
import weakref
from abc import abstractmethod, ABCMeta
from collections import defaultdict, namedtuple
import numpy as np
from six.moves import queue
from ..utils.timer import *
@@ -53,12 +54,14 @@ class SimulatorProcess(multiprocessing.Process):
#cnt = 0
while True:
state = player.current_state()
c2s_socket.send(dumps(state), copy=False)
#c2s_socket.send(dumps(state), copy=False)
c2s_socket.send('h')
#with total_timer('client recv_action'):
data = s2c_socket.recv(copy=False)
action = loads(data)
reward, isOver = player.action(action)
c2s_socket.send(dumps((reward, isOver)), copy=False)
#with total_timer('client recv_ack'):
ACK = s2c_socket.recv(copy=False)
#cnt += 1
#if cnt % 100 == 0:
@@ -103,7 +106,7 @@ class SimulatorMaster(threading.Thread):
self.daemon = True
# queueing messages to client
self.send_queue = queue.Queue(maxsize=50)
self.send_queue = queue.Queue(maxsize=100)
self.send_thread = LoopThread(lambda:
self.s2c_socket.send_multipart(self.send_queue.get()))
self.send_thread.start()
@@ -123,7 +126,8 @@ class SimulatorMaster(threading.Thread):
client = self.clients[ident]
client.protocol_state = 1 - client.protocol_state # first flip the state
if not client.protocol_state == 0: # state-action
state = loads(msg)
#state = loads(msg)
state = np.zeros((84, 84, 4), dtype='float32')
self._on_state(state, ident)
else: # reward-response
reward, isOver = loads(msg)
......
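The hunk above temporarily replaces the serialized game state with one-byte placeholders (the client sends 'h'; the master substitutes a fixed zeros array), apparently to separate serialization cost from ZMQ transport cost while profiling. A minimal sketch of that comparison, not part of the commit (socket type, endpoint, and message count are assumptions):

```python
import time
import numpy as np
import zmq

# Hypothetical one-way throughput probe: full state payload vs. 1-byte ping.
ctx = zmq.Context()
a = ctx.socket(zmq.PAIR)
b = ctx.socket(zmq.PAIR)
a.bind('ipc://bench.sock')
b.connect('ipc://bench.sock')

state = np.zeros((84, 84, 4), dtype='float32')  # same shape as the stub above
for name, payload in [('state', state.tobytes()), ('ping', b'h')]:
    start = time.time()
    for _ in range(1000):
        a.send(payload, copy=False)
        b.recv(copy=False)
    print('{}: {:.1f} us/msg'.format(name, (time.time() - start) * 1e3))
```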
@@ -10,7 +10,7 @@ import uuid
import os
from .base import ProxyDataFlow
from ..utils.concurrency import ensure_proc_terminate
from ..utils.concurrency import *
from ..utils.serialize import *
from ..utils import logger
@@ -107,8 +107,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
self.procs = [PrefetchProcessZMQ(self.ds, self.pipename)
for _ in range(self.nr_proc)]
for x in self.procs:
x.start()
start_proc_mask_signal(self.procs)
# __del__ not guaranteed to get called at exit
import atexit
......
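Here start_proc_mask_signal replaces the bare x.start() loop so that child prefetch processes do not receive the parent's Ctrl-C. The real helper lives in ..utils.concurrency; a minimal sketch of what it can look like (this version is an assumption, not the library code):

```python
import signal

def start_proc_mask_signal(procs):
    # Ignore SIGINT before starting; forked children inherit the disposition,
    # so Ctrl-C is delivered only to the parent, which can clean them up.
    old_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        for p in procs:
            p.start()
    finally:
        signal.signal(signal.SIGINT, old_handler)  # restore in the parent
```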
@@ -81,42 +81,36 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
self.outqueue.put((tid, self.func(dp)))
class PredictorWorkerThread(threading.Thread):
def __init__(self, queue, pred_func, id):
def __init__(self, queue, pred_func, id, batch_size=5):
super(PredictorWorkerThread, self).__init__()
self.queue = queue
self.func = pred_func
self.daemon = True
self.batch_size = batch_size
self.id = id
def run(self):
#self.xxx = None
def fetch():
batched = []
futures = []
batched, futures = [], []
inp, f = self.queue.get()
batched.append(inp)
futures.append(f)
#print "func queue:", self.queue.qsize()
#return batched, futures
if self.batch_size == 1:
return batched, futures
while True:
try:
inp, f = self.queue.get_nowait()
batched.append(inp)
futures.append(f)
if len(batched) == 5:
if len(batched) == self.batch_size:
break
except queue.Empty:
break
return batched, futures
#self.xxx = None
while True:
# normal input
#inputs, f = self.queue.get()
#outputs = self.func(inputs)
#f.set_result(outputs)
batched, futures = fetch()
#print "batched size: ", len(batched)
#print "batched size: ", len(batched), "queuesize: ", self.queue.qsize()
outputs = self.func([batched])
#if self.xxx is None:
#outputs = self.func([batched])
@@ -134,13 +128,13 @@ class MultiThreadAsyncPredictor(object):
An online predictor (using the current active session) that works with
QueueInputTrainer. Uses an async interface; supports multi-threading and multi-GPU.
"""
def __init__(self, trainer, input_names, output_names, nr_thread):
def __init__(self, trainer, input_names, output_names, nr_thread, batch_size=5):
"""
:param trainer: a `QueueInputTrainer` instance.
"""
self.input_queue = queue.Queue(maxsize=nr_thread*10)
self.threads = [
PredictorWorkerThread(self.input_queue, f, id)
PredictorWorkerThread(self.input_queue, f, id, batch_size)
for id, f in enumerate(
trainer.get_predict_funcs(
input_names, output_names, nr_thread))]
......
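A hypothetical usage sketch of the predictor after this change. Only the constructor signature comes from the diff; put_task and the tensor names are assumptions:

```python
# 'trainer' is a QueueInputTrainer; tensor names are illustrative.
predictor = MultiThreadAsyncPredictor(
    trainer, ['state'], ['policy', 'value'],
    nr_thread=4, batch_size=5)

# Each task enqueues one datapoint with a Future; a worker thread drains up
# to batch_size datapoints (see fetch() above), runs them in one session
# call, and resolves each Future with its slice of the batched output.
future = predictor.put_task([state])  # hypothetical API
policy, value = future.result()
```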
@@ -76,6 +76,7 @@ class EnqueueThread(threading.Thread):
self.queue = queue
self.close_op = self.queue.close(cancel_pending_enqueues=True)
self.size_op = self.queue.size()
self.daemon = True
def run(self):
@@ -86,6 +87,8 @@ class EnqueueThread(threading.Thread):
if self.coord.should_stop():
return
feed = dict(zip(self.input_vars, dp))
#_, size = self.sess.run([self.op, self.size_op], feed_dict=feed)
#print size
self.op.run(feed_dict=feed)
except tf.errors.CancelledError as e:
pass
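For context, a minimal sketch of the queue-feeding pattern EnqueueThread implements (TF 0.x-era API to match the code; shapes and capacity are assumptions). The commented-out lines above additionally fetch size_op to watch the queue depth while debugging:

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 84, 84, 4])  # illustrative shape
input_queue = tf.FIFOQueue(50, dtypes=[tf.float32])
enqueue_op = input_queue.enqueue([x])
size_op = input_queue.size()

# Inside the thread's loop, per datapoint dp:
#   sess.run(enqueue_op, feed_dict={x: dp})
# or, to also log the queue depth as in the commented debug lines:
#   _, size = sess.run([enqueue_op, size_op], feed_dict={x: dp})
```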
@@ -102,7 +105,9 @@ class QueueInputTrainer(Trainer):
Support multi GPU.
"""
def __init__(self, config, input_queue=None, async=False):
def __init__(self, config, input_queue=None,
async=False,
predict_tower=None):
"""
:param config: a `TrainConfig` instance
:param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
@@ -119,6 +124,10 @@ class QueueInputTrainer(Trainer):
if self.async:
assert self.config.nr_tower > 1
self.dequed_inputs = []
if predict_tower is None:
# by default, only use the first training tower for prediction
predict_tower = [0]
self.predict_tower = predict_tower
@staticmethod
def _average_grads(tower_grads):
@@ -144,6 +153,15 @@ class QueueInputTrainer(Trainer):
self.dequed_inputs.append(ret)
return ret
def _build_predict_tower(self):
inputs = self.model.get_input_vars()
for k in self.predict_tower:
logger.info("Building graph for predict tower 0{}...".format(k))
with tf.device('/gpu:{}'.format(k)), \
tf.name_scope('tower0{}'.format(k)):
self.model.build_graph(inputs, False)
tf.get_variable_scope().reuse_variables()
def _single_tower_grad(self):
""" Get grad and cost for single-tower case"""
model_inputs = self._get_model_inputs()
@@ -159,12 +177,14 @@ class QueueInputTrainer(Trainer):
# to avoid repeated summary from each device
collect_dedup = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
kept_summaries = {}
for k in collect_dedup:
del tf.get_collection_ref(k)[:]
grad_list = []
for i in range(self.config.nr_tower):
with tf.device('/gpu:{}'.format(i)), \
tf.name_scope('tower{}'.format(i)) as scope:
logger.info("Building graph for tower {}...".format(i))
logger.info("Building graph for training tower {}...".format(i))
model_inputs = self._get_model_inputs() # each tower dequeue from input queue
self.model.build_graph(model_inputs, True)
cost_var = self.model.get_cost() # build tower
@@ -186,6 +206,7 @@ class QueueInputTrainer(Trainer):
def train(self):
enqueue_op = self.input_queue.enqueue(self.input_vars)
self._build_predict_tower()
if self.config.nr_tower > 1:
grad_list = self._multi_tower_grads()
if not self.async:
@@ -219,10 +240,13 @@ class QueueInputTrainer(Trainer):
self.threads.append(th)
self.async_running = False
self.init_session_and_coord()
# create a thread that keeps filling the queue
self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
self.extra_threads_procs.append(self.input_th)
# do nothing in training
#self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
self.main_loop()
def run_step(self):
@@ -247,22 +271,18 @@ class QueueInputTrainer(Trainer):
"""
:param tower: return the kth predict_func
"""
tower = tower % self.config.nr_tower
tower = self.predict_tower[tower % len(self.predict_tower)]
if self.config.nr_tower > 1:
logger.info("Prepare a predictor function for tower{} ...".format(tower))
logger.info("Prepare a predictor function for tower0{} ...".format(tower))
raw_input_vars = get_vars_by_names(input_names)
input_var_idxs = [self.input_vars.index(v) for v in raw_input_vars]
dequed = self.dequed_inputs[tower]
input_vars = [dequed[k] for k in input_var_idxs]
if self.config.nr_tower > 1:
output_names = ['tower{}/'.format(tower) + n for n in output_names]
output_names = ['tower0{}/'.format(tower) + n for n in output_names]
output_vars = get_vars_by_names(output_names)
def func(inputs):
assert len(inputs) == len(input_vars)
feed = dict(zip(input_vars, inputs))
assert len(inputs) == len(raw_input_vars)
feed = dict(zip(raw_input_vars, inputs))
return self.sess.run(output_vars, feed_dict=feed)
return func
......
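Taken together, the trainer changes let inference run on a dedicated tower instead of borrowing a training tower's dequeued inputs. A hypothetical wiring sketch (Python 2-era code, since async became a keyword in Python 3.7; the config and tensor names are assumptions, the constructor arguments follow the diff):

```python
# 'config' is a TrainConfig built elsewhere; async needs nr_tower > 1.
trainer = QueueInputTrainer(
    config,
    async=True,          # asynchronous multi-GPU updates
    predict_tower=[2])   # build the shared-weight inference graph on /gpu:2

# Round-robins over predict_tower; outputs are looked up under the
# 'tower0{k}' name scope created in _build_predict_tower.
funcs = trainer.get_predict_funcs(['state'], ['policy'], 4)
trainer.train()
```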
@@ -37,7 +37,7 @@ def print_total_timer():
if len(_TOTAL_TIMER_DATA) == 0:
return
for k, v in six.iteritems(_TOTAL_TIMER_DATA):
logger.info("Total Time: {} -> {} sec, {} times".format(
k, v.sum, v.count))
logger.info("Total Time: {} -> {} sec, {} times, {} sec/time".format(
k, v.sum, v.count, v.average))
atexit.register(print_total_timer)
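With the added average, a total_timer section like the commented-out client code above now reports per-call latency at exit. An illustrative example (numbers made up):

```python
with total_timer('client recv_action'):
    data = s2c_socket.recv(copy=False)

# printed by the atexit hook:
# Total Time: client recv_action -> 12.3 sec, 10000 times, 0.00123 sec/time
```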