Commit 3a431489 authored by Yuxin Wu's avatar Yuxin Wu

breakout img

parent e1c0102e
![breakout](breakout.jpg)
Reproduce the following reinforcement learning methods: Reproduce the following reinforcement learning methods:
+ Nature-DQN in: + Nature-DQN in:
......
...@@ -87,17 +87,17 @@ class OfflinePredictor(OnlinePredictor): ...@@ -87,17 +87,17 @@ class OfflinePredictor(OnlinePredictor):
sess, input_vars, output_vars, config.return_input) sess, input_vars, output_vars, config.return_input)
class AsyncOnlinePredictor(PredictorBase): #class AsyncOnlinePredictor(PredictorBase):
def __init__(self, sess, enqueue_op, output_vars, return_input=False): #def __init__(self, sess, enqueue_op, output_vars, return_input=False):
""" #"""
:param enqueue_op: an op to feed inputs with. #:param enqueue_op: an op to feed inputs with.
:param output_vars: a list of directly-runnable (no extra feeding requirements) #:param output_vars: a list of directly-runnable (no extra feeding requirements)
vars producing the outputs. #vars producing the outputs.
""" #"""
self.session = sess #self.session = sess
self.enqop = enqueue_op #self.enqop = enqueue_op
self.output_vars = output_vars #self.output_vars = output_vars
self.return_input = return_input #self.return_input = return_input
def put_task(self, dp, callback): #def put_task(self, dp, callback):
pass #pass
...@@ -107,7 +107,7 @@ class Trainer(object): ...@@ -107,7 +107,7 @@ class Trainer(object):
get_global_step_var() # ensure there is such var, before finalizing the graph get_global_step_var() # ensure there is such var, before finalizing the graph
logger.info("Setup callbacks ...") logger.info("Setup callbacks ...")
callbacks = self.config.callbacks callbacks = self.config.callbacks
callbacks.setup_graph(self) callbacks.setup_graph(self) # TODO use weakref instead?
logger.info("Initializing graph variables ...") logger.info("Initializing graph variables ...")
self.sess.run(tf.initialize_all_variables()) self.sess.run(tf.initialize_all_variables())
self.config.session_init.init(self.sess) self.config.session_init.init(self.sess)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment