Commit e309e627 authored by Yuxin Wu's avatar Yuxin Wu

update gym env

parent 17863b8b
......@@ -49,5 +49,7 @@ with tf.Graph().as_default() as G:
var_dict[name] = v
logger.info("Variables to dump:")
logger.info(", ".join(var_dict.keys()))
saver = tf.train.Saver(var_list=var_dict)
saver = tf.train.Saver(
var_list=var_dict,
write_version=tf.train.SaverDef.V2)
saver.save(sess, args.output, write_meta_graph=False)
......@@ -34,6 +34,7 @@ class GymEnv(RLEnvironment):
if dumpdir:
mkdir_p(dumpdir)
self.gymenv.monitor.start(dumpdir)
self.use_dir = dumpdir
self.reset_stat()
self.rwd_counter = StatCounter()
......@@ -46,7 +47,8 @@ class GymEnv(RLEnvironment):
self._ob = self.gymenv.reset()
def finish_episode(self):
self.gymenv.monitor.flush()
if self.use_dir is not None:
self.gymenv.monitor.flush()
self.stats['score'].append(self.rwd_counter.sum)
def current_state(self):
......
......@@ -33,10 +33,17 @@ class ModelSaver(Callback):
for key in self.var_collections:
vars.extend(tf.get_collection(key))
self.path = os.path.join(logger.LOG_DIR, 'model')
self.saver = tf.train.Saver(
var_list=ModelSaver._get_var_dict(vars),
max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq)
try:
self.saver = tf.train.Saver(
var_list=ModelSaver._get_var_dict(vars),
max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq,
write_version=tf.train.SaverDef.V2)
except:
self.saver = tf.train.Saver(
var_list=ModelSaver._get_var_dict(vars),
max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq)
self.meta_graph_written = False
@staticmethod
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment