Commit 1555899d authored by Yuxin Wu

more general atari/common

parent 943b1701
@@ -171,7 +171,7 @@ def get_config():
             HumanHyperParamSetter(ObjAttrParam(dataset_train, 'exploration'), 'hyper.txt'),
             RunOp(lambda: M.update_target_param()),
             dataset_train,
-            PeriodicCallback(Evaluator(EVAL_EPISODE), 2),
+            PeriodicCallback(Evaluator(EVAL_EPISODE, 'fct/output:0'), 2),
         ]),
         # save memory for multiprocess evaluator
         session_config=get_default_sess_config(0.3),
@@ -194,10 +194,15 @@ if __name__ == '__main__':
         assert args.load is not None
     ROM_FILE = args.rom

+    if args.task != 'train':
+        cfg = PredictConfig(
+            model=Model(),
+            session_init=SaverRestore(args.load),
+            output_var_names=['fct/output:0'])
     if args.task == 'play':
-        play_model(Model(), args.load)
+        play_model(cfg)
     elif args.task == 'eval':
-        eval_model_multithread(Model(), args.load, EVAL_EPISODE)
+        eval_model_multithread(cfg, EVAL_EPISODE)
     else:
         config = get_config()
         if args.load:
...
@@ -28,13 +28,8 @@ def play_one_episode(player, func, verbose=False):
         return act
     return np.mean(player.play_one_episode(f))

-def play_model(M, model_path):
+def play_model(cfg):
     player = get_player(viz=0.01)
-    cfg = PredictConfig(
-        model=M,
-        input_data_mapping=[0],
-        session_init=SaverRestore(model_path),
-        output_var_names=['fct/output:0'])
     predfunc = get_predict_func(cfg)
     while True:
         score = play_one_episode(player, predfunc)
@@ -73,25 +68,21 @@ def eval_with_funcs(predict_funcs, nr_eval):
         return (stat.average, stat.max)
     return (0, 0)

-def eval_model_multithread(M, model_path, nr_eval):
-    cfg = PredictConfig(
-        model=M,
-        input_data_mapping=[0],
-        session_init=SaverRestore(model_path),
-        output_var_names=['fct/output:0'])
+def eval_model_multithread(cfg, nr_eval):
     func = get_predict_func(cfg)
     NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
     mean, max = eval_with_funcs([func] * NR_PROC, nr_eval)
     logger.info("Average Score: {}; Max Score: {}".format(mean, max))

 class Evaluator(Callback):
-    def __init__(self, nr_eval):
+    def __init__(self, nr_eval, output_name):
         self.eval_episode = nr_eval
+        self.output_name = output_name

     def _before_train(self):
         NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
         self.pred_funcs = [self.trainer.get_predict_func(
-            ['state'], ['fct/output'])] * NR_PROC
+            ['state'], [self.output_name])] * NR_PROC

     def _trigger_epoch(self):
         t = time.time()
...
@@ -84,9 +84,6 @@ def get_predict_func(config):
     config.session_init.init(sess)

     def run_input(dp):
-        assert len(input_map) == len(dp), \
-            "Graph has {} inputs but dataset only gives {} components!".format(
-                len(input_map), len(dp))
        feed = dict(zip(input_map, dp))
         return sess.run(output_vars, feed_dict=feed)
     return run_input
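
For context, a minimal sketch of how a caller might use the refactored helpers after this change. The names PredictConfig, SaverRestore, Model, play_model, eval_model_multithread and the 'fct/output:0' output are taken from the diff above; the import lines and the run_task wrapper are assumptions for illustration only, not part of this commit.

# Hypothetical usage sketch: build one PredictConfig up front and hand it to
# the common helpers, instead of passing (model, model_path) into each of them.
from tensorpack import PredictConfig, SaverRestore       # assumed import path
from common import play_model, eval_model_multithread    # helpers refactored in this commit

def run_task(task, load_path, eval_episode):
    cfg = PredictConfig(
        model=Model(),                          # the example's model, assumed to be in scope
        session_init=SaverRestore(load_path),   # restore trained weights from a checkpoint
        output_var_names=['fct/output:0'])      # output tensor name used by this example
    if task == 'play':
        play_model(cfg)                            # play with visualization, per common.py
    elif task == 'eval':
        eval_model_multithread(cfg, eval_episode)  # average score over eval_episode episodes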