Commit 126fd26a authored by Yuxin Wu's avatar Yuxin Wu

[WIP] queuetrainer + multigputrainer, predict tower

parent 254af98b
...@@ -9,6 +9,7 @@ import threading ...@@ -9,6 +9,7 @@ import threading
import weakref import weakref
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
from collections import defaultdict, namedtuple from collections import defaultdict, namedtuple
import numpy as np
from six.moves import queue from six.moves import queue
from ..utils.timer import * from ..utils.timer import *
...@@ -53,12 +54,14 @@ class SimulatorProcess(multiprocessing.Process): ...@@ -53,12 +54,14 @@ class SimulatorProcess(multiprocessing.Process):
#cnt = 0 #cnt = 0
while True: while True:
state = player.current_state() state = player.current_state()
c2s_socket.send(dumps(state), copy=False) #c2s_socket.send(dumps(state), copy=False)
c2s_socket.send('h')
#with total_timer('client recv_action'): #with total_timer('client recv_action'):
data = s2c_socket.recv(copy=False) data = s2c_socket.recv(copy=False)
action = loads(data) action = loads(data)
reward, isOver = player.action(action) reward, isOver = player.action(action)
c2s_socket.send(dumps((reward, isOver)), copy=False) c2s_socket.send(dumps((reward, isOver)), copy=False)
#with total_timer('client recv_ack'):
ACK = s2c_socket.recv(copy=False) ACK = s2c_socket.recv(copy=False)
#cnt += 1 #cnt += 1
#if cnt % 100 == 0: #if cnt % 100 == 0:
...@@ -103,7 +106,7 @@ class SimulatorMaster(threading.Thread): ...@@ -103,7 +106,7 @@ class SimulatorMaster(threading.Thread):
self.daemon = True self.daemon = True
# queueing messages to client # queueing messages to client
self.send_queue = queue.Queue(maxsize=50) self.send_queue = queue.Queue(maxsize=100)
self.send_thread = LoopThread(lambda: self.send_thread = LoopThread(lambda:
self.s2c_socket.send_multipart(self.send_queue.get())) self.s2c_socket.send_multipart(self.send_queue.get()))
self.send_thread.start() self.send_thread.start()
...@@ -123,7 +126,8 @@ class SimulatorMaster(threading.Thread): ...@@ -123,7 +126,8 @@ class SimulatorMaster(threading.Thread):
client = self.clients[ident] client = self.clients[ident]
client.protocol_state = 1 - client.protocol_state # first flip the state client.protocol_state = 1 - client.protocol_state # first flip the state
if not client.protocol_state == 0: # state-action if not client.protocol_state == 0: # state-action
state = loads(msg) #state = loads(msg)
state = np.zeros((84, 84, 4), dtype='float32')
self._on_state(state, ident) self._on_state(state, ident)
else: # reward-response else: # reward-response
reward, isOver = loads(msg) reward, isOver = loads(msg)
......
...@@ -10,7 +10,7 @@ import uuid ...@@ -10,7 +10,7 @@ import uuid
import os import os
from .base import ProxyDataFlow from .base import ProxyDataFlow
from ..utils.concurrency import ensure_proc_terminate from ..utils.concurrency import *
from ..utils.serialize import * from ..utils.serialize import *
from ..utils import logger from ..utils import logger
...@@ -107,8 +107,7 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -107,8 +107,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
self.procs = [PrefetchProcessZMQ(self.ds, self.pipename) self.procs = [PrefetchProcessZMQ(self.ds, self.pipename)
for _ in range(self.nr_proc)] for _ in range(self.nr_proc)]
for x in self.procs: start_proc_mask_signal(self.procs)
x.start()
# __del__ not guranteed to get called at exit # __del__ not guranteed to get called at exit
import atexit import atexit
......
...@@ -81,42 +81,36 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker): ...@@ -81,42 +81,36 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
self.outqueue.put((tid, self.func(dp))) self.outqueue.put((tid, self.func(dp)))
class PredictorWorkerThread(threading.Thread): class PredictorWorkerThread(threading.Thread):
def __init__(self, queue, pred_func, id): def __init__(self, queue, pred_func, id, batch_size=5):
super(PredictorWorkerThread, self).__init__() super(PredictorWorkerThread, self).__init__()
self.queue = queue self.queue = queue
self.func = pred_func self.func = pred_func
self.daemon = True self.daemon = True
self.batch_size = batch_size
self.id = id self.id = id
def run(self): def run(self):
#self.xxx = None
def fetch(): def fetch():
batched = [] batched, futures = [], []
futures = []
inp, f = self.queue.get() inp, f = self.queue.get()
batched.append(inp) batched.append(inp)
futures.append(f) futures.append(f)
#print "func queue:", self.queue.qsize() if self.batch_size == 1:
#return batched, futures return batched, futures
while True: while True:
try: try:
inp, f = self.queue.get_nowait() inp, f = self.queue.get_nowait()
batched.append(inp) batched.append(inp)
futures.append(f) futures.append(f)
if len(batched) == 5: if len(batched) == self.batch_size:
break break
except queue.Empty: except queue.Empty:
break break
return batched, futures return batched, futures
#self.xxx = None #self.xxx = None
while True: while True:
# normal input
#inputs, f = self.queue.get()
#outputs = self.func(inputs)
#f.set_result(outputs)
batched, futures = fetch() batched, futures = fetch()
#print "batched size: ", len(batched) #print "batched size: ", len(batched), "queuesize: ", self.queue.qsize()
outputs = self.func([batched]) outputs = self.func([batched])
#if self.xxx is None: #if self.xxx is None:
#outputs = self.func([batched]) #outputs = self.func([batched])
...@@ -134,13 +128,13 @@ class MultiThreadAsyncPredictor(object): ...@@ -134,13 +128,13 @@ class MultiThreadAsyncPredictor(object):
An online predictor (use the current active session) that works with An online predictor (use the current active session) that works with
QueueInputTrainer. Use async interface, support multi-thread and multi-GPU. QueueInputTrainer. Use async interface, support multi-thread and multi-GPU.
""" """
def __init__(self, trainer, input_names, output_names, nr_thread): def __init__(self, trainer, input_names, output_names, nr_thread, batch_size=5):
""" """
:param trainer: a `QueueInputTrainer` instance. :param trainer: a `QueueInputTrainer` instance.
""" """
self.input_queue = queue.Queue(maxsize=nr_thread*10) self.input_queue = queue.Queue(maxsize=nr_thread*10)
self.threads = [ self.threads = [
PredictorWorkerThread(self.input_queue, f, id) PredictorWorkerThread(self.input_queue, f, id, batch_size)
for id, f in enumerate( for id, f in enumerate(
trainer.get_predict_funcs( trainer.get_predict_funcs(
input_names, output_names, nr_thread))] input_names, output_names, nr_thread))]
......
...@@ -76,6 +76,7 @@ class EnqueueThread(threading.Thread): ...@@ -76,6 +76,7 @@ class EnqueueThread(threading.Thread):
self.queue = queue self.queue = queue
self.close_op = self.queue.close(cancel_pending_enqueues=True) self.close_op = self.queue.close(cancel_pending_enqueues=True)
self.size_op = self.queue.size()
self.daemon = True self.daemon = True
def run(self): def run(self):
...@@ -86,6 +87,8 @@ class EnqueueThread(threading.Thread): ...@@ -86,6 +87,8 @@ class EnqueueThread(threading.Thread):
if self.coord.should_stop(): if self.coord.should_stop():
return return
feed = dict(zip(self.input_vars, dp)) feed = dict(zip(self.input_vars, dp))
#_, size = self.sess.run([self.op, self.size_op], feed_dict=feed)
#print size
self.op.run(feed_dict=feed) self.op.run(feed_dict=feed)
except tf.errors.CancelledError as e: except tf.errors.CancelledError as e:
pass pass
...@@ -102,7 +105,9 @@ class QueueInputTrainer(Trainer): ...@@ -102,7 +105,9 @@ class QueueInputTrainer(Trainer):
Support multi GPU. Support multi GPU.
""" """
def __init__(self, config, input_queue=None, async=False): def __init__(self, config, input_queue=None,
async=False,
predict_tower=None):
""" """
:param config: a `TrainConfig` instance :param config: a `TrainConfig` instance
:param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints. :param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
...@@ -119,6 +124,10 @@ class QueueInputTrainer(Trainer): ...@@ -119,6 +124,10 @@ class QueueInputTrainer(Trainer):
if self.async: if self.async:
assert self.config.nr_tower > 1 assert self.config.nr_tower > 1
self.dequed_inputs = [] self.dequed_inputs = []
if predict_tower is None:
# by default, only use first training tower for prediction
predict_tower = [0]
self.predict_tower = predict_tower
@staticmethod @staticmethod
def _average_grads(tower_grads): def _average_grads(tower_grads):
...@@ -144,6 +153,15 @@ class QueueInputTrainer(Trainer): ...@@ -144,6 +153,15 @@ class QueueInputTrainer(Trainer):
self.dequed_inputs.append(ret) self.dequed_inputs.append(ret)
return ret return ret
def _build_predict_tower(self):
inputs = self.model.get_input_vars()
for k in self.predict_tower:
logger.info("Building graph for predict tower 0{}...".format(k))
with tf.device('/gpu:{}'.format(k)), \
tf.name_scope('tower0{}'.format(k)):
self.model.build_graph(inputs, False)
tf.get_variable_scope().reuse_variables()
def _single_tower_grad(self): def _single_tower_grad(self):
""" Get grad and cost for single-tower case""" """ Get grad and cost for single-tower case"""
model_inputs = self._get_model_inputs() model_inputs = self._get_model_inputs()
...@@ -159,12 +177,14 @@ class QueueInputTrainer(Trainer): ...@@ -159,12 +177,14 @@ class QueueInputTrainer(Trainer):
# to avoid repeated summary from each device # to avoid repeated summary from each device
collect_dedup = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY] collect_dedup = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
kept_summaries = {} kept_summaries = {}
for k in collect_dedup:
del tf.get_collection_ref(k)[:]
grad_list = [] grad_list = []
for i in range(self.config.nr_tower): for i in range(self.config.nr_tower):
with tf.device('/gpu:{}'.format(i)), \ with tf.device('/gpu:{}'.format(i)), \
tf.name_scope('tower{}'.format(i)) as scope: tf.name_scope('tower{}'.format(i)) as scope:
logger.info("Building graph for tower {}...".format(i)) logger.info("Building graph for training tower {}...".format(i))
model_inputs = self._get_model_inputs() # each tower dequeue from input queue model_inputs = self._get_model_inputs() # each tower dequeue from input queue
self.model.build_graph(model_inputs, True) self.model.build_graph(model_inputs, True)
cost_var = self.model.get_cost() # build tower cost_var = self.model.get_cost() # build tower
...@@ -186,6 +206,7 @@ class QueueInputTrainer(Trainer): ...@@ -186,6 +206,7 @@ class QueueInputTrainer(Trainer):
def train(self): def train(self):
enqueue_op = self.input_queue.enqueue(self.input_vars) enqueue_op = self.input_queue.enqueue(self.input_vars)
self._build_predict_tower()
if self.config.nr_tower > 1: if self.config.nr_tower > 1:
grad_list = self._multi_tower_grads() grad_list = self._multi_tower_grads()
if not self.async: if not self.async:
...@@ -219,10 +240,13 @@ class QueueInputTrainer(Trainer): ...@@ -219,10 +240,13 @@ class QueueInputTrainer(Trainer):
self.threads.append(th) self.threads.append(th)
self.async_running = False self.async_running = False
self.init_session_and_coord() self.init_session_and_coord()
# create a thread that keeps filling the queue # create a thread that keeps filling the queue
self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars) self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
self.extra_threads_procs.append(self.input_th) self.extra_threads_procs.append(self.input_th)
# do nothing in training
#self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
self.main_loop() self.main_loop()
def run_step(self): def run_step(self):
...@@ -247,22 +271,18 @@ class QueueInputTrainer(Trainer): ...@@ -247,22 +271,18 @@ class QueueInputTrainer(Trainer):
""" """
:param tower: return the kth predict_func :param tower: return the kth predict_func
""" """
tower = tower % self.config.nr_tower tower = self.predict_tower[tower % len(self.predict_tower)]
if self.config.nr_tower > 1: if self.config.nr_tower > 1:
logger.info("Prepare a predictor function for tower{} ...".format(tower)) logger.info("Prepare a predictor function for tower0{} ...".format(tower))
raw_input_vars = get_vars_by_names(input_names) raw_input_vars = get_vars_by_names(input_names)
input_var_idxs = [self.input_vars.index(v) for v in raw_input_vars]
dequed = self.dequed_inputs[tower]
input_vars = [dequed[k] for k in input_var_idxs]
if self.config.nr_tower > 1: if self.config.nr_tower > 1:
output_names = ['tower{}/'.format(tower) + n for n in output_names] output_names = ['tower0{}/'.format(tower) + n for n in output_names]
output_vars = get_vars_by_names(output_names) output_vars = get_vars_by_names(output_names)
def func(inputs): def func(inputs):
assert len(inputs) == len(input_vars) assert len(inputs) == len(raw_input_vars)
feed = dict(zip(input_vars, inputs)) feed = dict(zip(raw_input_vars, inputs))
return self.sess.run(output_vars, feed_dict=feed) return self.sess.run(output_vars, feed_dict=feed)
return func return func
......
...@@ -37,7 +37,7 @@ def print_total_timer(): ...@@ -37,7 +37,7 @@ def print_total_timer():
if len(_TOTAL_TIMER_DATA) == 0: if len(_TOTAL_TIMER_DATA) == 0:
return return
for k, v in six.iteritems(_TOTAL_TIMER_DATA): for k, v in six.iteritems(_TOTAL_TIMER_DATA):
logger.info("Total Time: {} -> {} sec, {} times".format( logger.info("Total Time: {} -> {} sec, {} times, {} sec/time".format(
k, v.sum, v.count)) k, v.sum, v.count, v.average))
atexit.register(print_total_timer) atexit.register(print_total_timer)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment