Commit 6e7eef04 authored by Yuxin Wu's avatar Yuxin Wu

bug fix in expreplay

parent ec726f6c
......@@ -52,6 +52,14 @@ class ExpReplay(DataFlow, Callback):
def init_memory(self):
logger.info("Populating replay memory...")
# fill some for the history
old_exploration = self.exploration
self.exploration = 1
for k in range(self.history_len):
self._populate_exp()
self.exploration = old_exploration
with tqdm(total=self.populate_size) as pbar:
while len(self.mem) < self.populate_size:
self._populate_exp()
......
......@@ -90,7 +90,6 @@ class SaverRestore(SessionInit):
del vars_multimap[k]
yield ret
@staticmethod
def _read_checkpoint_vars(model_path):
reader = tf.train.NewCheckpointReader(model_path)
......@@ -112,8 +111,11 @@ class SaverRestore(SessionInit):
name = new_name
if name in vars_available:
var_dict[name].append(v)
vars_available.remove(name)
else:
logger.warn("Param {} not found in checkpoint! Will not restore.".format(v.op.name))
#for name in vars_available:
#logger.warn("Param {} in checkpoint doesn't appear in the graph!".format(name))
return var_dict
class ParamRestore(SessionInit):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment