Commit 80723d31 authored by Yuxin Wu

gpu eval & one_hot

parent f15c2181
@@ -22,7 +22,7 @@ from tensorpack.predict import PredictConfig, get_predict_func, MultiProcessPred
 from tensorpack.tfutils import symbolic_functions as symbf
 from tensorpack.callbacks import *
-from tensorpack.RL import AtariPlayer, ExpReplay
+from tensorpack.RL import *

 """
 Implement DQN in:
@@ -43,7 +43,7 @@ EXPLORATION_EPOCH_ANNEAL = 0.008
 END_EXPLORATION = 0.1
 MEMORY_SIZE = 1e6
-INIT_MEMORY_SIZE = 50000
+INIT_MEMORY_SIZE = 500
 STEP_PER_EPOCH = 10000
 EVAL_EPISODE = 100
@@ -86,7 +86,7 @@ class Model(ModelDesc):
     def _build_graph(self, inputs, is_training):
         state, action, reward, next_state, isOver = inputs
         self.predict_value = self._get_DQN_prediction(state, is_training)
-        action_onehot = symbf.one_hot(action, NUM_ACTIONS)
+        action_onehot = tf.one_hot(action, NUM_ACTIONS, 1.0, 0.0)
         pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1)  #Nx1
         max_pred_reward = tf.reduce_mean(tf.reduce_max(
             self.predict_value, 1), name='predict_reward')
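The hand-rolled symbf.one_hot becomes a direct tf.one_hot call (the helper itself is gutted further down in this commit). The one-hot mask is what lets the DQN loss read off the Q-value of the action actually taken; a minimal sketch with made-up shapes:

    import tensorflow as tf

    action = tf.constant([0, 2])                       # N taken-action indices
    predict_value = tf.constant([[1., 2., 3.],
                                 [4., 5., 6.]])        # NxA predicted Q-values
    onehot = tf.one_hot(action, 3, 1.0, 0.0)           # [[1,0,0],[0,0,1]]
    picked = tf.reduce_sum(predict_value * onehot, 1)  # [1., 6.]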
@@ -128,7 +128,6 @@ def current_predictor(state):
 def play_one_episode(player, func, verbose=False):
     tot_reward = 0
-    que = deque(maxlen=30)
     while True:
         s = player.current_state()
         outputs = func([[s]])
@@ -138,10 +137,6 @@ def play_one_episode(player, func, verbose=False):
             print action_value, act
         if random.random() < 0.01:
             act = random.choice(range(NUM_ACTIONS))
-        if len(que) == que.maxlen \
-                and que.count(que[0]) == que.maxlen:
-            act = 1   # hack, avoid stuck
-        que.append(act)
         if verbose:
             print(act)
         reward, isOver = player.action(act)
@@ -150,7 +145,7 @@ def play_one_episode(player, func, verbose=False):
     return tot_reward

 def play_model(model_path):
-    player = HistoryFramePlayer(get_player(0.01), FRAME_HISTORY)
+    player = PreventStuckPlayer(HistoryFramePlayer(get_player(0.01), FRAME_HISTORY), 30, 1)
     cfg = PredictConfig(
         model=Model(),
         input_data_mapping=[0],
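PreventStuckPlayer(player, 30, 1) packages the deque hack deleted above into a reusable wrapper: once the last 30 actions are all identical, it forces action 1 (typically FIRE in ALE, which restarts play). The real class lives in tensorpack.RL; a minimal self-contained sketch of the idea, with the proxying details assumed:

    from collections import deque

    class PreventStuckPlayer(object):
        """Sketch: force `action` after `nr_repeat` identical actions."""
        def __init__(self, player, nr_repeat, action):
            self.player = player
            self.act_que = deque(maxlen=nr_repeat)
            self.trigger_action = action

        def action(self, act):
            self.act_que.append(act)
            if self.act_que.count(act) == self.act_que.maxlen:
                act = self.trigger_action   # break out of the repeated action
            return self.player.action(act)

        def current_state(self):
            return self.player.current_state()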
@@ -162,9 +157,8 @@ def play_model(model_path):
     print("Total:", score)

 def eval_model_multiprocess(model_path):
-    M = Model()
     cfg = PredictConfig(
-        model=M,
+        model=Model(),
         input_data_mapping=[0],
         session_init=SaverRestore(model_path),
         output_var_names=['fct/output:0'])
@@ -175,17 +169,16 @@ def eval_model_multiprocess(model_path):
             self.outq = outqueue
         def run(self):
-            player = HistoryFramePlayer(get_player(), FRAME_HISTORY)
+            player = PreventStuckPlayer(HistoryFramePlayer(get_player(), FRAME_HISTORY), 30, 1)
             self._init_runtime()
             while True:
                 score = play_one_episode(player, self.func)
                 self.outq.put(score)

-    NR_PROC = min(multiprocessing.cpu_count() // 2, 10)
-    procs = []
+    NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
     q = multiprocessing.Queue()
-    for k in range(NR_PROC):
-        procs.append(Worker(k, -1, cfg, q))
+    gpuid = get_gpus()[0]
+    procs = [Worker(k, gpuid, cfg, q) for k in range(NR_PROC)]
     ensure_proc_terminate(procs)
     for k in procs:
         k.start()
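Evaluation workers previously ran on CPU (gpuid -1); now every worker is handed the first visible GPU, which is what makes the memory-fraction changes below necessary. A self-contained toy of the fan-out pattern, where random scores stand in for play_one_episode and daemonizing is a crude stand-in for ensure_proc_terminate:

    import multiprocessing
    import random

    class Worker(multiprocessing.Process):
        def __init__(self, idx, gpuid, outq):
            super(Worker, self).__init__()
            self.idx = idx
            self.gpuid = gpuid   # the real worker builds its predictor here
            self.outq = outq

        def run(self):
            while True:
                self.outq.put(random.random())   # stand-in for an episode score

    if __name__ == '__main__':
        q = multiprocessing.Queue()
        procs = [Worker(k, 0, q) for k in range(4)]
        for p in procs:
            p.daemon = True   # die with the parent
            p.start()
        scores = [q.get() for _ in range(100)]
        print("mean score: {}".format(sum(scores) / len(scores)))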
@@ -202,8 +195,8 @@ class Evaluator(Callback):
     def _trigger_epoch(self):
         logger.info("Evaluating...")
         output = subproc_call(
-            "CUDA_VISIBLE_DEVICES= {} --task eval --rom {} --load {}".format(
-                sys.argv[0], romfile, os.path.join(logger.LOG_DIR, 'checkpoint')),
+            "{} --task eval --rom {} --load {}".format(
+                sys.argv[0], ROM_FILE, os.path.join(logger.LOG_DIR, 'checkpoint')),
             timeout=10*60)
         if output:
             last = output.strip().split('\n')[-1]
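Dropping the `CUDA_VISIBLE_DEVICES=` prefix is the "gpu eval" of the commit title: the evaluation subprocess now inherits the parent's GPUs instead of being forced onto CPU. subproc_call is a tensorpack utility; a rough stand-in for how it is used here (the body is an assumption, and the timeout argument requires Python 3's subprocess):

    import subprocess

    def subproc_call(cmd, timeout=None):
        """Rough stand-in: run cmd in a shell, return its combined output."""
        try:
            return subprocess.check_output(
                cmd, shell=True, stderr=subprocess.STDOUT, timeout=timeout)
        except subprocess.CalledProcessError as e:
            return e.output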
@@ -246,6 +239,8 @@ def get_config():
             dataset_train,
             PeriodicCallback(Evaluator(), 2),
         ]),
+        # save memory for multiprocess evaluator
+        session_config=get_default_sess_config(0.3),
         model=M,
         step_per_epoch=STEP_PER_EPOCH,
     )
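The new session_config is the other half of GPU evaluation: training caps itself at roughly 30% of GPU memory so the evaluator subprocesses can coexist on the same card. The helper's body is not part of this diff; a plausible sketch of what get_default_sess_config builds, with allow_growth assumed from the "start with minimal memory, but allow growth" comment in the next hunk:

    import tensorflow as tf

    def get_default_sess_config(mem_fraction=0.9):
        """Sketch: cap per-process GPU memory, grow the allocation lazily."""
        conf = tf.ConfigProto()
        conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
        conf.gpu_options.allow_growth = True
        return conf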
...
@@ -78,7 +78,8 @@ def get_predict_func(config):
     output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
                    for n in output_var_names]
-    sess = tf.Session()
+    # start with minimal memory, but allow growth
+    sess = tf.Session(config=get_default_sess_config(0.01))
     config.session_init.init(sess)

     def run_input(dp):
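With the predictor building its own 0.01-fraction growable session, the training process keeps nearly all of the memory it reserved. Usage stays as in play_one_episode above; a sketch, with the output indexing assumed:

    func = get_predict_func(cfg)   # owns a small, growable session
    outputs = func([[s]])          # batch of one stacked state
    q_values = outputs[0][0]       # first output var, first batch element
    act = q_values.argmax()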
...
@@ -4,6 +4,7 @@
 import tensorflow as tf
 import numpy as np
+from ..utils import logger

 def one_hot(y, num_labels):
     """
@@ -11,15 +12,8 @@ def one_hot(y, num_labels):
     :param num_labels: an int. number of output classes
     :returns: an NxC onehot matrix.
     """
-    with tf.op_scope([y, num_labels], 'one_hot'):
-        batch_size = tf.size(y)
-        y = tf.expand_dims(y, 1)
-        indices = tf.expand_dims(tf.range(0, batch_size), 1)
-        concated = tf.concat(1, [indices, y])
-        onehot_labels = tf.sparse_to_dense(
-            concated, tf.pack([batch_size, num_labels]), 1.0, 0.0)
-        onehot_labels.set_shape([None, num_labels])
-        return tf.cast(onehot_labels, tf.float32)
+    logger.warn("symbf.one_hot is deprecated in favor of more general tf.one_hot")
+    return tf.one_hot(y, num_labels, 1.0, 0.0, name='one_hot')

 def prediction_incorrect(logits, label, topk=1):
     """
...
@@ -12,7 +12,7 @@ import numpy as np
 from . import logger

 __all__ = ['timed_operation', 'change_env',
-           'get_rng', 'memoized', 'get_nr_gpu']
+           'get_rng', 'memoized', 'get_nr_gpu', 'get_gpus']

 #def expand_dim_if_necessary(var, dp):
 #    """
@@ -83,5 +83,10 @@ def get_rng(self):
 def get_nr_gpu():
     env = os.environ['CUDA_VISIBLE_DEVICES']
-    assert env is not None
+    assert env is not None # TODO
     return len(env.split(','))
+
+def get_gpus():
+    env = os.environ['CUDA_VISIBLE_DEVICES']
+    assert env is not None # TODO
+    return map(int, env.strip().split(','))
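One caveat in both helpers: os.environ['CUDA_VISIBLE_DEVICES'] raises KeyError when the variable is unset, so the assert never actually fires (hence, presumably, the "# TODO"), and it also passes for an empty string. On Python 3, map() would additionally return an iterator, breaking the get_gpus()[0] indexing used above; this codebase targets Python 2, as the print statements show. A more defensive sketch:

    import os

    def get_gpus():
        """Sketch: the physical GPU ids visible to this process."""
        env = os.environ.get('CUDA_VISIBLE_DEVICES')
        assert env, "CUDA_VISIBLE_DEVICES must be set and non-empty"
        return list(map(int, env.strip().split(',')))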