Commit b61ba3c9 authored by Yuxin Wu

better proc mask

parent 208de18c
...@@ -31,10 +31,12 @@ for atari games ...@@ -31,10 +31,12 @@ for atari games
""" """
BATCH_SIZE = 32 BATCH_SIZE = 32
IMAGE_SIZE = 84 IMAGE_SIZE = (84, 84)
FRAME_HISTORY = 4 FRAME_HISTORY = 4
ACTION_REPEAT = 4 ACTION_REPEAT = 4
HEIGHT_RANGE = (36, 204) # for breakout HEIGHT_RANGE = (36, 204) # for breakout
CHANNEL = FRAME_HISTORY
IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
#HEIGHT_RANGE = (28, -8) # for pong #HEIGHT_RANGE = (28, -8) # for pong
GAMMA = 0.99 GAMMA = 0.99
...@@ -51,7 +53,7 @@ NUM_ACTIONS = None ...@@ -51,7 +53,7 @@ NUM_ACTIONS = None
ROM_FILE = None ROM_FILE = None
def get_player(viz=False, train=False): def get_player(viz=False, train=False):
player = AtariPlayer(ROM_FILE, height_range=HEIGHT_RANGE, pl = AtariPlayer(ROM_FILE, height_range=HEIGHT_RANGE,
frame_skip=ACTION_REPEAT, image_shape=IMAGE_SIZE[::-1], viz=viz, frame_skip=ACTION_REPEAT, image_shape=IMAGE_SIZE[::-1], viz=viz,
live_lost_as_eoe=train) live_lost_as_eoe=train)
global NUM_ACTIONS global NUM_ACTIONS
...@@ -61,10 +63,10 @@ def get_player(viz=False, train=False): ...@@ -61,10 +63,10 @@ def get_player(viz=False, train=False):
class Model(ModelDesc): class Model(ModelDesc):
def _get_input_vars(self): def _get_input_vars(self):
assert NUM_ACTIONS is not None assert NUM_ACTIONS is not None
return [InputVar(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE, FRAME_HISTORY), 'state'), return [InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
InputVar(tf.int32, (None,), 'action'), InputVar(tf.int32, (None,), 'action'),
InputVar(tf.float32, (None,), 'reward'), InputVar(tf.float32, (None,), 'reward'),
InputVar(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE, FRAME_HISTORY), 'next_state'), InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'next_state'),
InputVar(tf.bool, (None,), 'isOver') ] InputVar(tf.bool, (None,), 'isOver') ]
def _get_DQN_prediction(self, image, is_training): def _get_DQN_prediction(self, image, is_training):
......
...@@ -100,7 +100,7 @@ def get_data(train_or_test, cifar_classnum): ...@@ -100,7 +100,7 @@ def get_data(train_or_test, cifar_classnum):
ds = AugmentImageComponent(ds, augmentors) ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, 128, remainder=not isTrain) ds = BatchData(ds, 128, remainder=not isTrain)
if isTrain: if isTrain:
ds = PrefetchData(ds, 10, 5) ds = PrefetchDataZMQ(ds, 5)
return ds return ds
def get_config(cifar_classnum): def get_config(cifar_classnum):
...@@ -156,5 +156,5 @@ if __name__ == '__main__': ...@@ -156,5 +156,5 @@ if __name__ == '__main__':
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
if args.gpu: if args.gpu:
config.nr_tower = len(args.gpu.split(',')) config.nr_tower = len(args.gpu.split(','))
#QueueInputTrainer(config).train() QueueInputTrainer(config).train()
SimpleTrainer(config).train() #SimpleTrainer(config).train()
...@@ -17,7 +17,7 @@ __all__ = ['SimulatorProcess', 'SimulatorMaster'] ...@@ -17,7 +17,7 @@ __all__ = ['SimulatorProcess', 'SimulatorMaster']
class SimulatorProcess(multiprocessing.Process): class SimulatorProcess(multiprocessing.Process):
""" A process that simulates a player """ """ A process that simulates a player """
__meta__ = ABCMeta __metaclass__ = ABCMeta
def __init__(self, idx, server_name): def __init__(self, idx, server_name):
""" """
......
...@@ -11,6 +11,7 @@ import tqdm ...@@ -11,6 +11,7 @@ import tqdm
import tensorflow as tf import tensorflow as tf
from .config import TrainConfig from .config import TrainConfig
from ..utils import * from ..utils import *
from ..utils.concurrency import start_proc_mask_signal
from ..callbacks import StatHolder from ..callbacks import StatHolder
from ..tfutils import * from ..tfutils import *
from ..tfutils.summary import create_summary from ..tfutils.summary import create_summary
...@@ -151,10 +152,7 @@ class Trainer(object): ...@@ -151,10 +152,7 @@ class Trainer(object):
sess=self.sess, coord=self.coord, daemon=True, start=True) sess=self.sess, coord=self.coord, daemon=True, start=True)
# avoid sigint get handled by other processes # avoid sigint get handled by other processes
sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) start_proc_mask_signal(self.extra_threads_procs)
for k in self.extra_threads_procs:
k.start()
signal.signal(signal.SIGINT, sigint_handler)
def process_grads(self, grads): def process_grads(self, grads):
......
...@@ -246,7 +246,8 @@ class QueueInputTrainer(Trainer): ...@@ -246,7 +246,8 @@ class QueueInputTrainer(Trainer):
:param tower: return the kth predict_func :param tower: return the kth predict_func
""" """
tower = tower % self.config.nr_tower tower = tower % self.config.nr_tower
logger.info("Prepare a predictor function for tower{} ...".format(tower)) if self.config.nr_tower > 1:
logger.info("Prepare a predictor function for tower{} ...".format(tower))
raw_input_vars = get_vars_by_names(input_names) raw_input_vars = get_vars_by_names(input_names)
input_var_idxs = [self.input_vars.index(v) for v in raw_input_vars] input_var_idxs = [self.input_vars.index(v) for v in raw_input_vars]
......
...@@ -7,6 +7,7 @@ import threading ...@@ -7,6 +7,7 @@ import threading
import multiprocessing import multiprocessing
import atexit import atexit
import bisect import bisect
import signal
import weakref import weakref
import six import six
if six.PY2: if six.PY2:
...@@ -17,7 +18,8 @@ else: ...@@ -17,7 +18,8 @@ else:
from . import logger from . import logger
__all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate', __all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE'] 'OrderedResultGatherProc', 'OrderedContainer', 'DIE',
'start_proc_mask_signal']
class StoppableThread(threading.Thread): class StoppableThread(threading.Thread):
def __init__(self): def __init__(self):
...@@ -76,6 +78,15 @@ def ensure_proc_terminate(proc): ...@@ -76,6 +78,15 @@ def ensure_proc_terminate(proc):
assert isinstance(proc, multiprocessing.Process) assert isinstance(proc, multiprocessing.Process)
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc)) atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
def start_proc_mask_signal(proc):
    """
    Start one or more processes while SIGINT is ignored in the caller.

    The caller's SIGINT handler is swapped for ``signal.SIG_IGN`` before the
    processes are started and is put back afterwards, even when one of the
    ``start()`` calls raises.

    :param proc: an object with a ``start()`` method, or a list/tuple of
        such objects.
    """
    if not isinstance(proc, (list, tuple)):
        proc = [proc]
    # Intent, per the call site in the trainer: avoid SIGINT getting handled
    # by the other processes.
    sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        for p in proc:
            p.start()
    finally:
        # Always restore, otherwise a failed start() would leave the caller
        # permanently ignoring Ctrl-C.
        signal.signal(signal.SIGINT, sigint_handler)
def subproc_call(cmd, timeout=None): def subproc_call(cmd, timeout=None):
try: try:
output = subprocess.check_output( output = subprocess.check_output(
...@@ -117,7 +128,6 @@ class OrderedContainer(object): ...@@ -117,7 +128,6 @@ class OrderedContainer(object):
self.wait_for += 1 self.wait_for += 1
return rank, ret return rank, ret
class OrderedResultGatherProc(multiprocessing.Process): class OrderedResultGatherProc(multiprocessing.Process):
""" """
Gather indexed data from a data queue, and produce results with the Gather indexed data from a data queue, and produce results with the
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment