Commit b61ba3c9 authored by Yuxin Wu

better proc mask

parent 208de18c
......@@ -31,10 +31,12 @@ for atari games
"""
BATCH_SIZE = 32
IMAGE_SIZE = 84
IMAGE_SIZE = (84, 84)
FRAME_HISTORY = 4
ACTION_REPEAT = 4
HEIGHT_RANGE = (36, 204) # for breakout
CHANNEL = FRAME_HISTORY
IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
#HEIGHT_RANGE = (28, -8) # for pong
GAMMA = 0.99
......@@ -51,7 +53,7 @@ NUM_ACTIONS = None
ROM_FILE = None
def get_player(viz=False, train=False):
player = AtariPlayer(ROM_FILE, height_range=HEIGHT_RANGE,
pl = AtariPlayer(ROM_FILE, height_range=HEIGHT_RANGE,
frame_skip=ACTION_REPEAT, image_shape=IMAGE_SIZE[::-1], viz=viz,
live_lost_as_eoe=train)
global NUM_ACTIONS
......@@ -61,10 +63,10 @@ def get_player(viz=False, train=False):
class Model(ModelDesc):
def _get_input_vars(self):
assert NUM_ACTIONS is not None
return [InputVar(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE, FRAME_HISTORY), 'state'),
return [InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
InputVar(tf.int32, (None,), 'action'),
InputVar(tf.float32, (None,), 'reward'),
InputVar(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE, FRAME_HISTORY), 'next_state'),
InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'next_state'),
InputVar(tf.bool, (None,), 'isOver') ]
def _get_DQN_prediction(self, image, is_training):
......
......@@ -100,7 +100,7 @@ def get_data(train_or_test, cifar_classnum):
ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, 128, remainder=not isTrain)
if isTrain:
ds = PrefetchData(ds, 10, 5)
ds = PrefetchDataZMQ(ds, 5)
return ds
def get_config(cifar_classnum):
......@@ -156,5 +156,5 @@ if __name__ == '__main__':
config.session_init = SaverRestore(args.load)
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
#QueueInputTrainer(config).train()
SimpleTrainer(config).train()
QueueInputTrainer(config).train()
#SimpleTrainer(config).train()
......@@ -17,7 +17,7 @@ __all__ = ['SimulatorProcess', 'SimulatorMaster']
class SimulatorProcess(multiprocessing.Process):
""" A process that simulates a player """
__meta__ = ABCMeta
__metaclass__ = ABCMeta
def __init__(self, idx, server_name):
"""
......
......@@ -11,6 +11,7 @@ import tqdm
import tensorflow as tf
from .config import TrainConfig
from ..utils import *
from ..utils.concurrency import start_proc_mask_signal
from ..callbacks import StatHolder
from ..tfutils import *
from ..tfutils.summary import create_summary
......@@ -151,10 +152,7 @@ class Trainer(object):
sess=self.sess, coord=self.coord, daemon=True, start=True)
# avoid sigint get handled by other processes
sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
for k in self.extra_threads_procs:
k.start()
signal.signal(signal.SIGINT, sigint_handler)
start_proc_mask_signal(self.extra_threads_procs)
def process_grads(self, grads):
......
......@@ -246,6 +246,7 @@ class QueueInputTrainer(Trainer):
:param tower: return the kth predict_func
"""
tower = tower % self.config.nr_tower
if self.config.nr_tower > 1:
logger.info("Prepare a predictor function for tower{} ...".format(tower))
raw_input_vars = get_vars_by_names(input_names)
input_var_idxs = [self.input_vars.index(v) for v in raw_input_vars]
......
......@@ -7,6 +7,7 @@ import threading
import multiprocessing
import atexit
import bisect
import signal
import weakref
import six
if six.PY2:
......@@ -17,7 +18,8 @@ else:
from . import logger
__all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE']
'OrderedResultGatherProc', 'OrderedContainer', 'DIE',
'start_proc_mask_signal']
class StoppableThread(threading.Thread):
def __init__(self):
......@@ -76,6 +78,15 @@ def ensure_proc_terminate(proc):
assert isinstance(proc, multiprocessing.Process)
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
def start_proc_mask_signal(proc):
    """
    Start one or more processes while SIGINT is temporarily ignored in the
    calling process, then put the caller's previous SIGINT handler back.

    :param proc: a single object exposing ``start()``, or a list of them.
        Anything that is not a ``list`` is treated as a single object.
    """
    if not isinstance(proc, list):
        proc = [proc]

    # SIGINT is set to SIG_IGN only for the duration of the start() calls.
    # Presumably this is so the started processes do not react to Ctrl-C
    # themselves (they are expected to pick up the ignore setting from the
    # parent at start time) -- confirm for non-fork start methods.
    previous_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        for p in proc:
            p.start()
    finally:
        # Restore even when a start() call raises; otherwise the caller
        # would be left ignoring SIGINT for the rest of its lifetime.
        signal.signal(signal.SIGINT, previous_handler)
def subproc_call(cmd, timeout=None):
try:
output = subprocess.check_output(
......@@ -117,7 +128,6 @@ class OrderedContainer(object):
self.wait_for += 1
return rank, ret
class OrderedResultGatherProc(multiprocessing.Process):
"""
Gather indexed data from a data queue, and produce results with the
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment