Commit b7ee409b authored by Yuxin Wu

small changes in __main__

parent bc551406
......@@ -60,10 +60,6 @@ ENV_NAME = None
def get_player(viz=False, train=False, dumpdir=None):
pl = GymEnv(ENV_NAME, viz=viz, dumpdir=dumpdir)
pl = MapPlayerState(pl, lambda img: cv2.resize(img, IMAGE_SIZE[::-1]))
global NUM_ACTIONS
NUM_ACTIONS = pl.get_action_space().num_actions()
pl = HistoryFramePlayer(pl, FRAME_HISTORY)
if not train:
pl = PreventStuckPlayer(pl, 30, 1)
......@@ -201,8 +197,6 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def get_config():
dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
logger.set_logger_dir(dirname)
M = Model()
name_base = str(uuid.uuid1())[:6]
......@@ -251,17 +245,15 @@ if __name__ == '__main__':
args = parser.parse_args()
ENV_NAME = args.env
assert ENV_NAME
logger.info("Environment Name: {}".format(ENV_NAME))
p = get_player()
del p # set NUM_ACTIONS
NUM_ACTIONS = get_player().get_action_space().num_actions()
logger.info("Number of actions: {}".format(NUM_ACTIONS))
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.task != 'train':
assert args.load is not None
if args.task != 'train':
assert args.load is not None
cfg = PredictConfig(
model=Model(),
session_init=get_model_loader(args.load),
......@@ -277,7 +269,11 @@ if __name__ == '__main__':
OfflinePredictor(cfg), args.episode)
# gym.upload(output, api_key='xxx')
else:
dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
logger.set_logger_dir(dirname)
nr_gpu = get_nr_gpu()
trainer = QueueInputTrainer
if nr_gpu > 0:
if nr_gpu > 1:
predict_tower = list(range(nr_gpu))[-nr_gpu // 2:]
......@@ -285,12 +281,12 @@ if __name__ == '__main__':
predict_tower = [0]
PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
train_tower = list(range(nr_gpu))[:-nr_gpu // 2] or [0]
logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format(
logger.info("[Batch-A3C] Train on gpu {} and infer on gpu {}".format(
','.join(map(str, train_tower)), ','.join(map(str, predict_tower))))
if len(train_tower) > 1:
trainer = AsyncMultiGPUTrainer
else:
logger.warn("Without GPU this model will never learn! CPU is only useful for debug.")
nr_gpu = 0
PREDICTOR_THREAD = 1
predict_tower, train_tower = [0], [0]
trainer = QueueInputTrainer
......
......@@ -149,17 +149,14 @@ if __name__ == '__main__':
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.task != 'train':
assert args.load is not None
ROM_FILE = args.rom
METHOD = args.algo
# set num_actions
pl = AtariPlayer(ROM_FILE, viz=False)
NUM_ACTIONS = pl.get_action_space().num_actions()
del pl
NUM_ACTIONS = AtariPlayer(ROM_FILE).get_action_space().num_actions()
logger.info("ROM: {}, Num Actions: {}".format(ROM_FILE, NUM_ACTIONS))
if args.task != 'train':
assert args.load is not None
cfg = PredictConfig(
model=Model(),
session_init=get_model_loader(args.load),
......@@ -171,8 +168,8 @@ if __name__ == '__main__':
eval_model_multithread(cfg, EVAL_EPISODE, get_player)
else:
logger.set_logger_dir(
'train_log/DQN-{}'.format(
os.path.basename(ROM_FILE).split('.')[0]))
os.path.join('train_log', 'DQN-{}'.format(
os.path.basename(ROM_FILE).split('.')[0])))
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
......
......@@ -112,8 +112,6 @@ def get_data(train_or_test, cifar_classnum):
def get_config(cifar_classnum):
logger.auto_set_dir()
# prepare dataset
dataset_train = get_data('train', cifar_classnum)
dataset_test = get_data('test', cifar_classnum)
......@@ -145,10 +143,9 @@ if __name__ == '__main__':
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
else:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
with tf.Graph().as_default():
logger.set_logger_dir(os.path.join('train_log', 'cifar' + str(args.classnum)))
config = get_config(args.classnum)
if args.load:
config.session_init = SaverRestore(args.load)
......@@ -156,7 +153,7 @@ if __name__ == '__main__':
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
nr_gpu = get_nr_gpu()
if nr_gpu == 1:
if nr_gpu <= 1:
QueueInputTrainer(config).train()
else:
SyncMultiGPUTrainer(config).train()
......@@ -102,9 +102,6 @@ def get_data():
def get_config():
# automatically setup the directory train_log/mnist-convnet for logging
logger.auto_set_dir()
dataset_train, dataset_test = get_data()
# How many iterations you want in each epoch.
# This is the default value, don't actually need to set it in the config
......@@ -136,9 +133,12 @@ if __name__ == '__main__':
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# automatically setup the directory train_log/mnist-convnet for logging
logger.auto_set_dir()
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
# SimpleTrainer is slow, this is just a demo.
SimpleTrainer(config).train()
# You can use QueueInputTrainer instead
SimpleTrainer(config).train()
......@@ -94,7 +94,6 @@ def get_data():
def get_config():
logger.auto_set_dir()
data_train, data_test = get_data()
return TrainConfig(
......@@ -120,6 +119,7 @@ if __name__ == '__main__':
else:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
logger.auto_set_dir()
with tf.Graph().as_default():
config = get_config()
if args.load:
......
......@@ -40,6 +40,9 @@ class HistoryBuffer(object):
class HistoryFramePlayer(ProxyPlayer):
""" Include history frames in state, or use black images.
It assumes the underlying player will do auto-restart.
Map the original frames into (H, W, HIST x channels).
Oldest frames first.
"""
def __init__(self, player, hist_len):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment