Commit a5652699 authored by Yuxin Wu

some clean-ups, and add an alias `hooked_sess` (#147)

parent ccf4a5a0
......@@ -88,8 +88,8 @@ class WGANTrainer(FeedfreeTrainerBase):
def run_step(self):
for k in range(5):
self.monitored_sess.run(self.d_min)
self.monitored_sess.run(self.g_min)
self.hooked_sess.run(self.d_min)
self.hooked_sess.run(self.g_min)
if __name__ == '__main__':
......
......@@ -74,10 +74,10 @@ class Callback(object):
Same as ``tf.train.SessionRunHook.before_run``.
"""
fetches = self._before_run(ctx)
if isinstance(fetches, tf.train.SessionRunArgs):
return fetches
if fetches is None:
return None
if isinstance(fetches, tf.train.SessionRunArgs):
return fetches
# also support list of names
assert isinstance(fetches, list), fetches
......
......@@ -10,8 +10,7 @@ from six.moves import zip
import tqdm
from ..utils import logger, get_tqdm_kwargs
from ..utils.naming import (GLOBAL_STEP_INCR_OP_NAME,
LOCAL_STEP_OP_NAME)
from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
from ..tfutils.common import (
get_op_tensor_name, get_global_step_var,
get_global_step_value, get_op_or_tensor_by_name)
......@@ -61,9 +60,10 @@ class MaintainStepCounter(Callback):
self.gs_incr_var = tf.assign_add(
gs_var, 1,
name=GLOBAL_STEP_INCR_OP_NAME)
tf.mod(
self.gs_incr_var, self.trainer.config.steps_per_epoch,
name=LOCAL_STEP_OP_NAME)
# tf.mod(
# self.gs_incr_var, self.trainer.config.steps_per_epoch,
# name=LOCAL_STEP_OP_NAME)
self._fetches = tf.train.SessionRunArgs(self.gs_incr_var)
def _before_train(self):
gs_val = get_global_step_value()
......@@ -75,7 +75,7 @@ class MaintainStepCounter(Callback):
# increase global_step, when trainer.local_step changed
if self.trainer.local_step != self._last_updated:
self._last_updated = self.trainer.local_step
return [self.gs_incr_var.op]
return self._fetches
else:
return None
......@@ -93,12 +93,14 @@ class ProgressBar(Callback):
self._tags = [get_op_tensor_name(n)[0].split("/")[-1] for n in names]
def _before_train(self):
self._fetches = get_op_or_tensor_by_name(self._names)
self._last_updated = self.trainer.local_step
self._total = self.trainer.config.steps_per_epoch
self._tqdm_args = get_tqdm_kwargs(leave=True)
if len(self._names):
self._fetches = get_op_or_tensor_by_name(self._names) or None
if self._fetches:
self._fetches = tf.train.SessionRunArgs(self._fetches)
self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "
def _before_run(self, _):
......@@ -114,7 +116,7 @@ class ProgressBar(Callback):
def _after_run(self, _, run_values):
res = run_values.results
if len(res):
if res:
self._bar.set_postfix(zip(self._tags, res))
def _trigger_step(self):
......
......@@ -8,16 +8,14 @@ from six.moves import map
from ..utils.naming import (
GLOBAL_STEP_VAR_NAME,
GLOBAL_STEP_OP_NAME,
LOCAL_STEP_VAR_NAME)
from ..utils import logger
GLOBAL_STEP_OP_NAME)
from ..utils.argtools import memoized
__all__ = ['get_default_sess_config',
'get_global_step_value',
'get_global_step_var',
'get_local_step_var',
#'get_local_step_var',
'get_op_tensor_name',
'get_tensors_by_names',
......@@ -75,13 +73,13 @@ def get_global_step_value():
get_global_step_var())
@memoized
def get_local_step_var():
try:
return tf.get_default_graph().get_tensor_by_name(LOCAL_STEP_VAR_NAME)
except KeyError:
logger.warn("get_local_step_var() is only available to use in callbacks!")
raise
# @memoized
# def get_local_step_var():
# try:
# return tf.get_default_graph().get_tensor_by_name(LOCAL_STEP_VAR_NAME)
# except KeyError:
# logger.warn("get_local_step_var() is only available to use in callbacks!")
# raise
def get_op_tensor_name(name):
......
......@@ -129,6 +129,7 @@ def add_moving_summary(v, *args, **kwargs):
decay, num_updates=get_global_step_var(), name='EMA')
avg_maintain_op = averager.apply(v)
for c in v:
# TODO do this in the EMA callback?
name = re.sub('tower[p0-9]+/', '', c.op.name)
tf.summary.scalar(name + '-summary', averager.average(c))
......
......@@ -140,7 +140,9 @@ class Trainer(object):
session_creator=tf.train.ChiefSessionCreator(
scaffold=scaffold, config=self.config.session_config),
hooks=self.config.callbacks.get_hooks())
self.sess = self.monitored_sess._tf_sess()
self.hooked_sess = self.monitored_sess # just create an alias
self.sess = self.monitored_sess._tf_sess() # expose the underlying session also
self.config.session_init._run_init(self.sess)
@abstractmethod
......
......@@ -48,7 +48,7 @@ class FeedfreeTrainerBase(Trainer):
def run_step(self):
""" Simply run ``self.train_op``, which minimizes the cost."""
self.monitored_sess.run(self.train_op)
self.hooked_sess.run(self.train_op)
# if not hasattr(self, 'cnt'):
# self.cnt = 0
# else:
......
......@@ -87,7 +87,7 @@ class SimpleTrainer(Trainer):
def run_step(self):
""" Feed data into the graph and run the updates. """
feed = self._input_method.next_feed()
self.monitored_sess.run(self.train_op, feed_dict=feed)
self.hooked_sess.run(self.train_op, feed_dict=feed)
def _setup(self):
self._input_method._setup(self)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment