Commit e309e627 authored by Yuxin Wu's avatar Yuxin Wu

update gym env

parent 17863b8b
...@@ -49,5 +49,7 @@ with tf.Graph().as_default() as G: ...@@ -49,5 +49,7 @@ with tf.Graph().as_default() as G:
var_dict[name] = v var_dict[name] = v
logger.info("Variables to dump:") logger.info("Variables to dump:")
logger.info(", ".join(var_dict.keys())) logger.info(", ".join(var_dict.keys()))
saver = tf.train.Saver(var_list=var_dict) saver = tf.train.Saver(
var_list=var_dict,
write_version=tf.train.SaverDef.V2)
saver.save(sess, args.output, write_meta_graph=False) saver.save(sess, args.output, write_meta_graph=False)
...@@ -34,6 +34,7 @@ class GymEnv(RLEnvironment): ...@@ -34,6 +34,7 @@ class GymEnv(RLEnvironment):
if dumpdir: if dumpdir:
mkdir_p(dumpdir) mkdir_p(dumpdir)
self.gymenv.monitor.start(dumpdir) self.gymenv.monitor.start(dumpdir)
self.use_dir = dumpdir
self.reset_stat() self.reset_stat()
self.rwd_counter = StatCounter() self.rwd_counter = StatCounter()
...@@ -46,7 +47,8 @@ class GymEnv(RLEnvironment): ...@@ -46,7 +47,8 @@ class GymEnv(RLEnvironment):
self._ob = self.gymenv.reset() self._ob = self.gymenv.reset()
def finish_episode(self): def finish_episode(self):
self.gymenv.monitor.flush() if self.use_dir is not None:
self.gymenv.monitor.flush()
self.stats['score'].append(self.rwd_counter.sum) self.stats['score'].append(self.rwd_counter.sum)
def current_state(self): def current_state(self):
......
...@@ -33,10 +33,17 @@ class ModelSaver(Callback): ...@@ -33,10 +33,17 @@ class ModelSaver(Callback):
for key in self.var_collections: for key in self.var_collections:
vars.extend(tf.get_collection(key)) vars.extend(tf.get_collection(key))
self.path = os.path.join(logger.LOG_DIR, 'model') self.path = os.path.join(logger.LOG_DIR, 'model')
self.saver = tf.train.Saver( try:
var_list=ModelSaver._get_var_dict(vars), self.saver = tf.train.Saver(
max_to_keep=self.keep_recent, var_list=ModelSaver._get_var_dict(vars),
keep_checkpoint_every_n_hours=self.keep_freq) max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq,
write_version=tf.train.SaverDef.V2)
except:
self.saver = tf.train.Saver(
var_list=ModelSaver._get_var_dict(vars),
max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq)
self.meta_graph_written = False self.meta_graph_written = False
@staticmethod @staticmethod
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment