Commit a5652699 authored by Yuxin Wu's avatar Yuxin Wu

some clean-ups, and add an alias `hooked_sess` (#147)

parent ccf4a5a0
...@@ -88,8 +88,8 @@ class WGANTrainer(FeedfreeTrainerBase): ...@@ -88,8 +88,8 @@ class WGANTrainer(FeedfreeTrainerBase):
def run_step(self): def run_step(self):
for k in range(5): for k in range(5):
self.monitored_sess.run(self.d_min) self.hooked_sess.run(self.d_min)
self.monitored_sess.run(self.g_min) self.hooked_sess.run(self.g_min)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -74,10 +74,10 @@ class Callback(object): ...@@ -74,10 +74,10 @@ class Callback(object):
Same as ``tf.train.SessionRunHook.before_run``. Same as ``tf.train.SessionRunHook.before_run``.
""" """
fetches = self._before_run(ctx) fetches = self._before_run(ctx)
if isinstance(fetches, tf.train.SessionRunArgs):
return fetches
if fetches is None: if fetches is None:
return None return None
if isinstance(fetches, tf.train.SessionRunArgs):
return fetches
# also support list of names # also support list of names
assert isinstance(fetches, list), fetches assert isinstance(fetches, list), fetches
......
...@@ -10,8 +10,7 @@ from six.moves import zip ...@@ -10,8 +10,7 @@ from six.moves import zip
import tqdm import tqdm
from ..utils import logger, get_tqdm_kwargs from ..utils import logger, get_tqdm_kwargs
from ..utils.naming import (GLOBAL_STEP_INCR_OP_NAME, from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
LOCAL_STEP_OP_NAME)
from ..tfutils.common import ( from ..tfutils.common import (
get_op_tensor_name, get_global_step_var, get_op_tensor_name, get_global_step_var,
get_global_step_value, get_op_or_tensor_by_name) get_global_step_value, get_op_or_tensor_by_name)
...@@ -61,9 +60,10 @@ class MaintainStepCounter(Callback): ...@@ -61,9 +60,10 @@ class MaintainStepCounter(Callback):
self.gs_incr_var = tf.assign_add( self.gs_incr_var = tf.assign_add(
gs_var, 1, gs_var, 1,
name=GLOBAL_STEP_INCR_OP_NAME) name=GLOBAL_STEP_INCR_OP_NAME)
tf.mod( # tf.mod(
self.gs_incr_var, self.trainer.config.steps_per_epoch, # self.gs_incr_var, self.trainer.config.steps_per_epoch,
name=LOCAL_STEP_OP_NAME) # name=LOCAL_STEP_OP_NAME)
self._fetches = tf.train.SessionRunArgs(self.gs_incr_var)
def _before_train(self): def _before_train(self):
gs_val = get_global_step_value() gs_val = get_global_step_value()
...@@ -75,7 +75,7 @@ class MaintainStepCounter(Callback): ...@@ -75,7 +75,7 @@ class MaintainStepCounter(Callback):
# increase global_step, when trainer.local_step changed # increase global_step, when trainer.local_step changed
if self.trainer.local_step != self._last_updated: if self.trainer.local_step != self._last_updated:
self._last_updated = self.trainer.local_step self._last_updated = self.trainer.local_step
return [self.gs_incr_var.op] return self._fetches
else: else:
return None return None
...@@ -93,12 +93,14 @@ class ProgressBar(Callback): ...@@ -93,12 +93,14 @@ class ProgressBar(Callback):
self._tags = [get_op_tensor_name(n)[0].split("/")[-1] for n in names] self._tags = [get_op_tensor_name(n)[0].split("/")[-1] for n in names]
def _before_train(self): def _before_train(self):
self._fetches = get_op_or_tensor_by_name(self._names)
self._last_updated = self.trainer.local_step self._last_updated = self.trainer.local_step
self._total = self.trainer.config.steps_per_epoch self._total = self.trainer.config.steps_per_epoch
self._tqdm_args = get_tqdm_kwargs(leave=True) self._tqdm_args = get_tqdm_kwargs(leave=True)
if len(self._names):
self._fetches = get_op_or_tensor_by_name(self._names) or None
if self._fetches:
self._fetches = tf.train.SessionRunArgs(self._fetches)
self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} " self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "
def _before_run(self, _): def _before_run(self, _):
...@@ -114,7 +116,7 @@ class ProgressBar(Callback): ...@@ -114,7 +116,7 @@ class ProgressBar(Callback):
def _after_run(self, _, run_values): def _after_run(self, _, run_values):
res = run_values.results res = run_values.results
if len(res): if res:
self._bar.set_postfix(zip(self._tags, res)) self._bar.set_postfix(zip(self._tags, res))
def _trigger_step(self): def _trigger_step(self):
......
...@@ -8,16 +8,14 @@ from six.moves import map ...@@ -8,16 +8,14 @@ from six.moves import map
from ..utils.naming import ( from ..utils.naming import (
GLOBAL_STEP_VAR_NAME, GLOBAL_STEP_VAR_NAME,
GLOBAL_STEP_OP_NAME, GLOBAL_STEP_OP_NAME)
LOCAL_STEP_VAR_NAME)
from ..utils import logger
from ..utils.argtools import memoized from ..utils.argtools import memoized
__all__ = ['get_default_sess_config', __all__ = ['get_default_sess_config',
'get_global_step_value', 'get_global_step_value',
'get_global_step_var', 'get_global_step_var',
'get_local_step_var', #'get_local_step_var',
'get_op_tensor_name', 'get_op_tensor_name',
'get_tensors_by_names', 'get_tensors_by_names',
...@@ -75,13 +73,13 @@ def get_global_step_value(): ...@@ -75,13 +73,13 @@ def get_global_step_value():
get_global_step_var()) get_global_step_var())
@memoized # @memoized
def get_local_step_var(): # def get_local_step_var():
try: # try:
return tf.get_default_graph().get_tensor_by_name(LOCAL_STEP_VAR_NAME) # return tf.get_default_graph().get_tensor_by_name(LOCAL_STEP_VAR_NAME)
except KeyError: # except KeyError:
logger.warn("get_local_step_var() is only available to use in callbacks!") # logger.warn("get_local_step_var() is only available to use in callbacks!")
raise # raise
def get_op_tensor_name(name): def get_op_tensor_name(name):
......
...@@ -129,6 +129,7 @@ def add_moving_summary(v, *args, **kwargs): ...@@ -129,6 +129,7 @@ def add_moving_summary(v, *args, **kwargs):
decay, num_updates=get_global_step_var(), name='EMA') decay, num_updates=get_global_step_var(), name='EMA')
avg_maintain_op = averager.apply(v) avg_maintain_op = averager.apply(v)
for c in v: for c in v:
# TODO do this in the EMA callback?
name = re.sub('tower[p0-9]+/', '', c.op.name) name = re.sub('tower[p0-9]+/', '', c.op.name)
tf.summary.scalar(name + '-summary', averager.average(c)) tf.summary.scalar(name + '-summary', averager.average(c))
......
...@@ -140,7 +140,9 @@ class Trainer(object): ...@@ -140,7 +140,9 @@ class Trainer(object):
session_creator=tf.train.ChiefSessionCreator( session_creator=tf.train.ChiefSessionCreator(
scaffold=scaffold, config=self.config.session_config), scaffold=scaffold, config=self.config.session_config),
hooks=self.config.callbacks.get_hooks()) hooks=self.config.callbacks.get_hooks())
self.sess = self.monitored_sess._tf_sess() self.hooked_sess = self.monitored_sess # just create an alias
self.sess = self.monitored_sess._tf_sess() # expose the underlying session also
self.config.session_init._run_init(self.sess) self.config.session_init._run_init(self.sess)
@abstractmethod @abstractmethod
......
...@@ -48,7 +48,7 @@ class FeedfreeTrainerBase(Trainer): ...@@ -48,7 +48,7 @@ class FeedfreeTrainerBase(Trainer):
def run_step(self): def run_step(self):
""" Simply run ``self.train_op``, which minimizes the cost.""" """ Simply run ``self.train_op``, which minimizes the cost."""
self.monitored_sess.run(self.train_op) self.hooked_sess.run(self.train_op)
# if not hasattr(self, 'cnt'): # if not hasattr(self, 'cnt'):
# self.cnt = 0 # self.cnt = 0
# else: # else:
......
...@@ -87,7 +87,7 @@ class SimpleTrainer(Trainer): ...@@ -87,7 +87,7 @@ class SimpleTrainer(Trainer):
def run_step(self): def run_step(self):
""" Feed data into the graph and run the updates. """ """ Feed data into the graph and run the updates. """
feed = self._input_method.next_feed() feed = self._input_method.next_feed()
self.monitored_sess.run(self.train_op, feed_dict=feed) self.hooked_sess.run(self.train_op, feed_dict=feed)
def _setup(self): def _setup(self):
self._input_method._setup(self) self._input_method._setup(self)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment