Commit b7ee409b authored by Yuxin Wu's avatar Yuxin Wu

small changes in __main__

parent bc551406
...@@ -60,10 +60,6 @@ ENV_NAME = None ...@@ -60,10 +60,6 @@ ENV_NAME = None
def get_player(viz=False, train=False, dumpdir=None): def get_player(viz=False, train=False, dumpdir=None):
pl = GymEnv(ENV_NAME, viz=viz, dumpdir=dumpdir) pl = GymEnv(ENV_NAME, viz=viz, dumpdir=dumpdir)
pl = MapPlayerState(pl, lambda img: cv2.resize(img, IMAGE_SIZE[::-1])) pl = MapPlayerState(pl, lambda img: cv2.resize(img, IMAGE_SIZE[::-1]))
global NUM_ACTIONS
NUM_ACTIONS = pl.get_action_space().num_actions()
pl = HistoryFramePlayer(pl, FRAME_HISTORY) pl = HistoryFramePlayer(pl, FRAME_HISTORY)
if not train: if not train:
pl = PreventStuckPlayer(pl, 30, 1) pl = PreventStuckPlayer(pl, 30, 1)
...@@ -201,8 +197,6 @@ class MySimulatorMaster(SimulatorMaster, Callback): ...@@ -201,8 +197,6 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def get_config(): def get_config():
dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
logger.set_logger_dir(dirname)
M = Model() M = Model()
name_base = str(uuid.uuid1())[:6] name_base = str(uuid.uuid1())[:6]
...@@ -251,17 +245,15 @@ if __name__ == '__main__': ...@@ -251,17 +245,15 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
ENV_NAME = args.env ENV_NAME = args.env
assert ENV_NAME
logger.info("Environment Name: {}".format(ENV_NAME)) logger.info("Environment Name: {}".format(ENV_NAME))
p = get_player() NUM_ACTIONS = get_player().get_action_space().num_actions()
del p # set NUM_ACTIONS logger.info("Number of actions: {}".format(NUM_ACTIONS))
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.task != 'train':
assert args.load is not None
if args.task != 'train': if args.task != 'train':
assert args.load is not None
cfg = PredictConfig( cfg = PredictConfig(
model=Model(), model=Model(),
session_init=get_model_loader(args.load), session_init=get_model_loader(args.load),
...@@ -277,7 +269,11 @@ if __name__ == '__main__': ...@@ -277,7 +269,11 @@ if __name__ == '__main__':
OfflinePredictor(cfg), args.episode) OfflinePredictor(cfg), args.episode)
# gym.upload(output, api_key='xxx') # gym.upload(output, api_key='xxx')
else: else:
dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
logger.set_logger_dir(dirname)
nr_gpu = get_nr_gpu() nr_gpu = get_nr_gpu()
trainer = QueueInputTrainer
if nr_gpu > 0: if nr_gpu > 0:
if nr_gpu > 1: if nr_gpu > 1:
predict_tower = list(range(nr_gpu))[-nr_gpu // 2:] predict_tower = list(range(nr_gpu))[-nr_gpu // 2:]
...@@ -285,12 +281,12 @@ if __name__ == '__main__': ...@@ -285,12 +281,12 @@ if __name__ == '__main__':
predict_tower = [0] predict_tower = [0]
PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
train_tower = list(range(nr_gpu))[:-nr_gpu // 2] or [0] train_tower = list(range(nr_gpu))[:-nr_gpu // 2] or [0]
logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format( logger.info("[Batch-A3C] Train on gpu {} and infer on gpu {}".format(
','.join(map(str, train_tower)), ','.join(map(str, predict_tower)))) ','.join(map(str, train_tower)), ','.join(map(str, predict_tower))))
trainer = AsyncMultiGPUTrainer if len(train_tower) > 1:
trainer = AsyncMultiGPUTrainer
else: else:
logger.warn("Without GPU this model will never learn! CPU is only useful for debug.") logger.warn("Without GPU this model will never learn! CPU is only useful for debug.")
nr_gpu = 0
PREDICTOR_THREAD = 1 PREDICTOR_THREAD = 1
predict_tower, train_tower = [0], [0] predict_tower, train_tower = [0], [0]
trainer = QueueInputTrainer trainer = QueueInputTrainer
......
...@@ -149,17 +149,14 @@ if __name__ == '__main__': ...@@ -149,17 +149,14 @@ if __name__ == '__main__':
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.task != 'train':
assert args.load is not None
ROM_FILE = args.rom ROM_FILE = args.rom
METHOD = args.algo METHOD = args.algo
# set num_actions # set num_actions
pl = AtariPlayer(ROM_FILE, viz=False) NUM_ACTIONS = AtariPlayer(ROM_FILE).get_action_space().num_actions()
NUM_ACTIONS = pl.get_action_space().num_actions() logger.info("ROM: {}, Num Actions: {}".format(ROM_FILE, NUM_ACTIONS))
del pl
if args.task != 'train': if args.task != 'train':
assert args.load is not None
cfg = PredictConfig( cfg = PredictConfig(
model=Model(), model=Model(),
session_init=get_model_loader(args.load), session_init=get_model_loader(args.load),
...@@ -171,8 +168,8 @@ if __name__ == '__main__': ...@@ -171,8 +168,8 @@ if __name__ == '__main__':
eval_model_multithread(cfg, EVAL_EPISODE, get_player) eval_model_multithread(cfg, EVAL_EPISODE, get_player)
else: else:
logger.set_logger_dir( logger.set_logger_dir(
'train_log/DQN-{}'.format( os.path.join('train_log', 'DQN-{}'.format(
os.path.basename(ROM_FILE).split('.')[0])) os.path.basename(ROM_FILE).split('.')[0])))
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
......
...@@ -112,8 +112,6 @@ def get_data(train_or_test, cifar_classnum): ...@@ -112,8 +112,6 @@ def get_data(train_or_test, cifar_classnum):
def get_config(cifar_classnum): def get_config(cifar_classnum):
logger.auto_set_dir()
# prepare dataset # prepare dataset
dataset_train = get_data('train', cifar_classnum) dataset_train = get_data('train', cifar_classnum)
dataset_test = get_data('test', cifar_classnum) dataset_test = get_data('test', cifar_classnum)
...@@ -145,10 +143,9 @@ if __name__ == '__main__': ...@@ -145,10 +143,9 @@ if __name__ == '__main__':
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
else:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
with tf.Graph().as_default(): with tf.Graph().as_default():
logger.set_logger_dir(os.path.join('train_log', 'cifar' + str(args.classnum)))
config = get_config(args.classnum) config = get_config(args.classnum)
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
...@@ -156,7 +153,7 @@ if __name__ == '__main__': ...@@ -156,7 +153,7 @@ if __name__ == '__main__':
if args.gpu: if args.gpu:
config.nr_tower = len(args.gpu.split(',')) config.nr_tower = len(args.gpu.split(','))
nr_gpu = get_nr_gpu() nr_gpu = get_nr_gpu()
if nr_gpu == 1: if nr_gpu <= 1:
QueueInputTrainer(config).train() QueueInputTrainer(config).train()
else: else:
SyncMultiGPUTrainer(config).train() SyncMultiGPUTrainer(config).train()
...@@ -102,9 +102,6 @@ def get_data(): ...@@ -102,9 +102,6 @@ def get_data():
def get_config(): def get_config():
# automatically setup the directory train_log/mnist-convnet for logging
logger.auto_set_dir()
dataset_train, dataset_test = get_data() dataset_train, dataset_test = get_data()
# How many iterations you want in each epoch. # How many iterations you want in each epoch.
# This is the default value, don't actually need to set it in the config # This is the default value, don't actually need to set it in the config
...@@ -136,9 +133,12 @@ if __name__ == '__main__': ...@@ -136,9 +133,12 @@ if __name__ == '__main__':
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# automatically setup the directory train_log/mnist-convnet for logging
logger.auto_set_dir()
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
# SimpleTrainer is slow, this is just a demo. # SimpleTrainer is slow, this is just a demo.
SimpleTrainer(config).train()
# You can use QueueInputTrainer instead # You can use QueueInputTrainer instead
SimpleTrainer(config).train()
...@@ -94,7 +94,6 @@ def get_data(): ...@@ -94,7 +94,6 @@ def get_data():
def get_config(): def get_config():
logger.auto_set_dir()
data_train, data_test = get_data() data_train, data_test = get_data()
return TrainConfig( return TrainConfig(
...@@ -120,6 +119,7 @@ if __name__ == '__main__': ...@@ -120,6 +119,7 @@ if __name__ == '__main__':
else: else:
os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ['CUDA_VISIBLE_DEVICES'] = '0'
logger.auto_set_dir()
with tf.Graph().as_default(): with tf.Graph().as_default():
config = get_config() config = get_config()
if args.load: if args.load:
......
...@@ -40,6 +40,9 @@ class HistoryBuffer(object): ...@@ -40,6 +40,9 @@ class HistoryBuffer(object):
class HistoryFramePlayer(ProxyPlayer): class HistoryFramePlayer(ProxyPlayer):
""" Include history frames in state, or use black images. """ Include history frames in state, or use black images.
It assumes the underlying player will do auto-restart. It assumes the underlying player will do auto-restart.
Map the original frames into (H, W, HIST x channels).
Oldest frames first.
""" """
def __init__(self, player, hist_len): def __init__(self, player, hist_len):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment