Commit 741f404d authored by Yuxin Wu's avatar Yuxin Wu

fix dump callback

parent 47cbfe87
......@@ -30,7 +30,6 @@ class Callback(object):
def before_train(self, trainer):
self.trainer = trainer
self.graph = tf.get_default_graph()
self.sess = tf.get_default_session()
self.epoch_num = 0
self._before_train()
......
......@@ -32,7 +32,7 @@ class DumpParamAsImage(Callback):
self.var = self.graph.get_tensor_by_name(self.var_name)
def _trigger_epoch(self):
val = self.sess.run(self.var)
val = self.trainer.sess.run(self.var)
if self.func is not None:
val = self.func(val)
if isinstance(val, list):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment