Commit 741f404d authored by Yuxin Wu's avatar Yuxin Wu

fix dump callback

parent 47cbfe87
...@@ -30,7 +30,6 @@ class Callback(object): ...@@ -30,7 +30,6 @@ class Callback(object):
def before_train(self, trainer): def before_train(self, trainer):
self.trainer = trainer self.trainer = trainer
self.graph = tf.get_default_graph() self.graph = tf.get_default_graph()
self.sess = tf.get_default_session()
self.epoch_num = 0 self.epoch_num = 0
self._before_train() self._before_train()
......
...@@ -32,7 +32,7 @@ class DumpParamAsImage(Callback): ...@@ -32,7 +32,7 @@ class DumpParamAsImage(Callback):
self.var = self.graph.get_tensor_by_name(self.var_name) self.var = self.graph.get_tensor_by_name(self.var_name)
def _trigger_epoch(self): def _trigger_epoch(self):
val = self.sess.run(self.var) val = self.trainer.sess.run(self.var)
if self.func is not None: if self.func is not None:
val = self.func(val) val = self.func(val)
if isinstance(val, list): if isinstance(val, list):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment