Commit da0984df authored by Yuxin Wu's avatar Yuxin Wu

use build_graph (fix #381)

parent f644da74
...@@ -21,7 +21,7 @@ class PeriodicTrigger(ProxyCallback): ...@@ -21,7 +21,7 @@ class PeriodicTrigger(ProxyCallback):
every_k_epochs (int): trigger when ``epoch_num % k == 0``. Set to every_k_epochs (int): trigger when ``epoch_num % k == 0``. Set to
None to disable. None to disable.
every_k_steps and every_k_epochs can be both set, but cannot be both NOne. every_k_steps and every_k_epochs can be both set, but cannot be both None.
""" """
assert isinstance(triggerable, Callback), type(triggerable) assert isinstance(triggerable, Callback), type(triggerable)
super(PeriodicTrigger, self).__init__(triggerable) super(PeriodicTrigger, self).__init__(triggerable)
...@@ -55,7 +55,8 @@ class PeriodicRunHooks(ProxyCallback): ...@@ -55,7 +55,8 @@ class PeriodicRunHooks(ProxyCallback):
""" """
Args: Args:
callback (Callback): callback (Callback):
every_k_steps(int): every_k_steps(int): call ``{before,after}_run`` when
``global_step % k == 0``.
""" """
self._every_k_steps = int(every_k_steps) self._every_k_steps = int(every_k_steps)
super(PeriodicRunHooks, self).__init__(callback) super(PeriodicRunHooks, self).__init__(callback)
...@@ -67,3 +68,6 @@ class PeriodicRunHooks(ProxyCallback): ...@@ -67,3 +68,6 @@ class PeriodicRunHooks(ProxyCallback):
def _after_run(self, ctx, rv): def _after_run(self, ctx, rv):
if self.global_step % self._every_k_steps == 0: if self.global_step % self._every_k_steps == 0:
self.cb._after_run(ctx, rv) self.cb._after_run(ctx, rv)
def __str__(self):
return "PeriodicRunHooks-" + str(self.cb)
...@@ -89,7 +89,7 @@ class ModelExport(object): ...@@ -89,7 +89,7 @@ class ModelExport(object):
""" """
logger.info('[export] build model for %s' % checkpoint) logger.info('[export] build model for %s' % checkpoint)
with TowerContext('', is_training=False): with TowerContext('', is_training=False):
self.model._build_graph(self.input) self.model.build_graph(self.input)
self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
# load values from latest checkpoint # load values from latest checkpoint
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment