Commit 6e7eef04 authored by Yuxin Wu

bug fix in expreplay

parent ec726f6c
@@ -52,6 +52,14 @@ class ExpReplay(DataFlow, Callback):
     def init_memory(self):
         logger.info("Populating replay memory...")
+        # fill some for the history
+        old_exploration = self.exploration
+        self.exploration = 1
+        for k in range(self.history_len):
+            self._populate_exp()
+        self.exploration = old_exploration
         with tqdm(total=self.populate_size) as pbar:
             while len(self.mem) < self.populate_size:
                 self._populate_exp()
...
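For context: the added lines pre-fill the memory with history_len transitions taken at exploration rate 1 (fully random actions), so a complete frame history exists before the normal epsilon-greedy population loop runs. Below is a minimal self-contained sketch of the same save/override/restore pattern; TinyReplay, its buffer, and the _populate_exp stub are hypothetical stand-ins for illustration, not tensorpack's actual classes:

    import random
    from collections import deque

    class TinyReplay:
        """Hypothetical replay buffer illustrating the init_memory fix:
        collect history_len fully-random transitions first, then fill
        the rest under the configured exploration rate."""

        def __init__(self, populate_size=100, history_len=4, exploration=0.1):
            self.mem = deque(maxlen=populate_size)
            self.populate_size = populate_size
            self.history_len = history_len
            self.exploration = exploration

        def _populate_exp(self):
            # With probability `exploration`, take a random action (stub).
            action = 'random' if random.random() < self.exploration else 'greedy'
            self.mem.append(action)

        def init_memory(self):
            # Temporarily act fully randomly so the first `history_len`
            # entries exist before any state history is assembled.
            old_exploration = self.exploration
            self.exploration = 1
            for _ in range(self.history_len):
                self._populate_exp()
            self.exploration = old_exploration
            # Then fill the remainder at the normal exploration rate.
            while len(self.mem) < self.populate_size:
                self._populate_exp()

    buf = TinyReplay()
    buf.init_memory()
    assert len(buf.mem) == buf.populate_size

The commit restores old_exploration by plain assignment; wrapping the override in try/finally would additionally protect the setting if _populate_exp raised mid-population.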
@@ -90,7 +90,6 @@ class SaverRestore(SessionInit):
                     del vars_multimap[k]
             yield ret
-
     @staticmethod
     def _read_checkpoint_vars(model_path):
         reader = tf.train.NewCheckpointReader(model_path)
@@ -112,8 +111,11 @@ class SaverRestore(SessionInit):
                 name = new_name
             if name in vars_available:
                 var_dict[name].append(v)
+                vars_available.remove(name)
             else:
                 logger.warn("Param {} not found in checkpoint! Will not restore.".format(v.op.name))
+        #for name in vars_available:
+            #logger.warn("Param {} in checkpoint doesn't appear in the graph!".format(name))
         return var_dict

class ParamRestore(SessionInit):
...
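The one-line fix in this hunk removes each matched name from vars_available, so a checkpoint entry is consumed at most once and, after the loop, whatever remains in vars_available is exactly the set of checkpoint params with no counterpart in the graph (which the commented-out lines would then report). A minimal sketch of that bookkeeping using plain name sets; match_vars and its arguments are hypothetical, not the actual SaverRestore API:

    import logging

    logger = logging.getLogger(__name__)

    def match_vars(graph_var_names, checkpoint_var_names):
        """Hypothetical sketch of the matching logic: map checkpoint
        names to graph variables, consuming each checkpoint name as it
        is matched so that leftovers can be reported afterwards."""
        vars_available = set(checkpoint_var_names)
        var_dict = {}
        for name in graph_var_names:
            if name in vars_available:
                var_dict.setdefault(name, []).append(name)
                # Consume the name: each checkpoint entry matches at most
                # once, and anything left over exists only in the checkpoint.
                vars_available.remove(name)
            else:
                logger.warning("Param %s not found in checkpoint! Will not restore.", name)
        for name in vars_available:
            logger.warning("Param %s in checkpoint doesn't appear in the graph!", name)
        return var_dict

    # Warns about 'fc0/W' (graph-only) and 'mean_rt' (checkpoint-only).
    matched = match_vars(['conv0/W', 'fc0/W'], ['conv0/W', 'mean_rt'])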