Commit eee05770 authored by Yuxin Wu's avatar Yuxin Wu

use hooks to run step triggers. examples unfixed. (#147)

parent 136174c9
...@@ -85,6 +85,7 @@ class Callback(object): ...@@ -85,6 +85,7 @@ class Callback(object):
if isinstance(f, (tf.Tensor, tf.Operation)): if isinstance(f, (tf.Tensor, tf.Operation)):
ret.append(f) ret.append(f)
else: else:
# warn about speed
ret.append(get_op_or_tensor_by_name(f)) ret.append(get_op_or_tensor_by_name(f))
return ret return ret
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import tensorflow as tf import tensorflow as tf
from contextlib import contextmanager from contextlib import contextmanager
from collections import defaultdict
import time import time
import traceback import traceback
...@@ -15,8 +14,20 @@ from ..utils import logger ...@@ -15,8 +14,20 @@ from ..utils import logger
__all__ = ['Callbacks'] __all__ = ['Callbacks']
class CallbackTimeLogger(object): class CallbackHook(tf.train.SessionRunHook):
def __init__(self, cb):
self.cb = cb
def before_run(self, _):
return tf.train.SessionRunArgs(
fetches=self.cb.extra_fetches())
def after_run(self, _, vals):
res = vals.results
self.cb.trigger_step(*res)
class CallbackTimeLogger(object):
def __init__(self): def __init__(self):
self.times = [] self.times = []
self.tot = 0 self.tot = 0
...@@ -90,30 +101,9 @@ class Callbacks(Callback): ...@@ -90,30 +101,9 @@ class Callbacks(Callback):
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
def _extra_fetches(self): def get_hooks(self):
if self._extra_fetches_cache is not None: # TODO skip
return self._extra_fetches_cache return [CallbackHook(cb) for cb in self.cbs]
# TODO use dispatch mechanism to avoid duplication
self._cbid_to_fetchid = defaultdict(list)
ret = []
for idx, cb in enumerate(self.cbs):
fetch = cb.extra_fetches()
if len(fetch) == 0:
continue
for f in fetch:
ret.append(f)
self._cbid_to_fetchid[idx].append(len(ret) - 1)
self._extra_fetches_cache = ret
return ret
def _trigger_step(self, *args):
for idx, cb in enumerate(self.cbs):
fid = self._cbid_to_fetchid[idx]
if len(fid) == 0:
cb.trigger_step()
else:
data = [args[k] for k in fid]
cb.trigger_step(*data)
def _trigger_epoch(self): def _trigger_epoch(self):
tm = CallbackTimeLogger() tm = CallbackTimeLogger()
......
...@@ -116,11 +116,24 @@ def get_tensors_by_names(names): ...@@ -116,11 +116,24 @@ def get_tensors_by_names(names):
def get_op_or_tensor_by_name(name): def get_op_or_tensor_by_name(name):
"""
Get either tf.Operation of tf.Tensor from names.
Args:
name (list[str] or str): names of operations or tensors.
"""
G = tf.get_default_graph() G = tf.get_default_graph()
if len(name) >= 3 and name[-2] == ':':
return G.get_tensor_by_name(name) def f(n):
if len(n) >= 3 and n[-2] == ':':
return G.get_tensor_by_name(n)
else:
return G.get_operation_by_name(n)
if not isinstance(name, list):
return f(name)
else: else:
return G.get_operation_by_name(name) return map(f, name)
def get_name_scope_name(): def get_name_scope_name():
......
...@@ -72,7 +72,8 @@ class Trainer(object): ...@@ -72,7 +72,8 @@ class Trainer(object):
This function should only get called after :meth:`setup()` has finished. This function should only get called after :meth:`setup()` has finished.
""" """
return self._extra_fetches # TODO remove this func
return []
def trigger_epoch(self): def trigger_epoch(self):
""" """
...@@ -130,7 +131,6 @@ class Trainer(object): ...@@ -130,7 +131,6 @@ class Trainer(object):
# some final operations that might modify the graph # some final operations that might modify the graph
logger.info("Setup callbacks graph ...") logger.info("Setup callbacks graph ...")
self.config.callbacks.setup_graph(weakref.proxy(self)) self.config.callbacks.setup_graph(weakref.proxy(self))
self._extra_fetches = self.config.callbacks.extra_fetches()
logger.info("Setup summaries ...") logger.info("Setup summaries ...")
self.summary_writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph()) self.summary_writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph())
...@@ -149,7 +149,7 @@ class Trainer(object): ...@@ -149,7 +149,7 @@ class Trainer(object):
self.monitored_sess = tf.train.MonitoredSession( self.monitored_sess = tf.train.MonitoredSession(
session_creator=tf.train.ChiefSessionCreator( session_creator=tf.train.ChiefSessionCreator(
scaffold=scaffold, config=self.config.session_config), scaffold=scaffold, config=self.config.session_config),
hooks=None) hooks=self.config.callbacks.get_hooks())
self.sess = self.monitored_sess._tf_sess() self.sess = self.monitored_sess._tf_sess()
self.config.session_init._run_init(self.sess) self.config.session_init._run_init(self.sess)
...@@ -182,12 +182,7 @@ class Trainer(object): ...@@ -182,12 +182,7 @@ class Trainer(object):
for self.local_step in range(self.config.steps_per_epoch): for self.local_step in range(self.config.steps_per_epoch):
if self.monitored_sess.should_stop(): if self.monitored_sess.should_stop():
return return
fetch_data = self.run_step() # implemented by subclass self.run_step() # implemented by subclass
if fetch_data is None:
# old trainer doesn't return fetch data
callbacks.trigger_step()
else:
callbacks.trigger_step(*fetch_data)
logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format( logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
self.epoch_num, self.global_step, time.time() - start_time)) self.epoch_num, self.global_step, time.time() - start_time))
......
...@@ -63,8 +63,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase): ...@@ -63,8 +63,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
def run_step(self): def run_step(self):
""" Simply run ``self.train_op``, which minimizes the cost.""" """ Simply run ``self.train_op``, which minimizes the cost."""
ret = self.sess.run([self.train_op] + self.get_extra_fetches()) self.monitored_sess.run(self.train_op)
return ret[1:]
# if not hasattr(self, 'cnt'): # if not hasattr(self, 'cnt'):
# self.cnt = 0 # self.cnt = 0
# else: # else:
......
...@@ -87,9 +87,7 @@ class SimpleTrainer(Trainer): ...@@ -87,9 +87,7 @@ class SimpleTrainer(Trainer):
def run_step(self): def run_step(self):
""" Feed data into the graph and run the updates. """ """ Feed data into the graph and run the updates. """
feed = self._input_method.next_feed() feed = self._input_method.next_feed()
ret = self.sess.run([self.train_op] + self.get_extra_fetches(), self.monitored_sess.run(self.train_op, feed_dict=feed)
feed_dict=feed)
return ret[1:]
def _setup(self): def _setup(self):
self._input_method._setup(self) self._input_method._setup(self)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment