Commit a6aa8bce authored by Yuxin Wu's avatar Yuxin Wu

Define global_step to be the number of hooked_sess.run calls

parent b747c068
...@@ -83,7 +83,6 @@ if __name__ == '__main__': ...@@ -83,7 +83,6 @@ if __name__ == '__main__':
max_epoch=200, max_epoch=200,
session_init=SaverRestore(args.load) if args.load else None session_init=SaverRestore(args.load) if args.load else None
) )
""" # The original code uses a different schedule, but this seems to work well.
The original code uses a different schedule, but this seems to work well. # Train 1 D after 2 G
"""
SeparateGANTrainer(config, d_period=3).train() SeparateGANTrainer(config, d_period=3).train()
...@@ -19,9 +19,9 @@ class Callback(object): ...@@ -19,9 +19,9 @@ class Callback(object):
for more detailed explanation of the callback methods. for more detailed explanation of the callback methods.
Attributes: Attributes:
epoch_num(int): the number of the current epoch. epoch_num(int): trainer.epoch_num
global_step(int): the number of global steps that have finished or is currently running. global_step(int): trainer.global_step
local_step(int): the local steps within the current epoch. local_step(int): trainer.local_step
trainer(Trainer): the trainer. trainer(Trainer): the trainer.
graph(tf.Graph): the graph. graph(tf.Graph): the graph.
......
...@@ -10,14 +10,11 @@ import tqdm ...@@ -10,14 +10,11 @@ import tqdm
from ..utils import logger from ..utils import logger
from ..utils.utils import get_tqdm_kwargs from ..utils.utils import get_tqdm_kwargs
from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
from ..tfutils.common import ( from ..tfutils.common import (
get_op_tensor_name, get_global_step_var, get_op_tensor_name, get_op_or_tensor_by_name)
get_global_step_value, get_op_or_tensor_by_name)
from .base import Callback from .base import Callback
__all__ = ['StepTensorPrinter', 'MaintainStepCounter', __all__ = ['StepTensorPrinter', 'ProgressBar']
'ProgressBar']
class StepTensorPrinter(Callback): class StepTensorPrinter(Callback):
...@@ -47,39 +44,6 @@ class StepTensorPrinter(Callback): ...@@ -47,39 +44,6 @@ class StepTensorPrinter(Callback):
logger.info("{}: {}".format(n, v)) logger.info("{}: {}".format(n, v))
class MaintainStepCounter(Callback):
"""
It maintains the global step in the graph, making sure it's increased by one in every `run_step` call.
This callback is always enabled by the trainer, and you wouldn't need to use it.
"""
def _setup_graph(self):
# ensure it exists
gs_var = get_global_step_var()
with tf.name_scope(None):
with tf.device(gs_var.device):
self.gs_incr_op = tf.assign_add(
gs_var, 1,
name=GLOBAL_STEP_INCR_OP_NAME).op
# tf.mod(
# self.gs_incr_var, self.trainer.config.steps_per_epoch,
# name=LOCAL_STEP_OP_NAME)
self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)
def _before_train(self):
gs_val = get_global_step_value()
if gs_val != 0:
logger.info("Start training with global_step={}".format(gs_val))
self._last_updated = self.local_step
def _before_run(self, _):
# increase global_step, when trainer.local_step changed
if self.local_step != self._last_updated:
self._last_updated = self.local_step
return self._fetches
else:
return None
class ProgressBar(Callback): class ProgressBar(Callback):
""" A progress bar based on tqdm. Enabled by default. """ """ A progress bar based on tqdm. Enabled by default. """
......
...@@ -8,17 +8,19 @@ from six.moves import range ...@@ -8,17 +8,19 @@ from six.moves import range
import tensorflow as tf import tensorflow as tf
from ..graph_builder.predictor_factory import PredictorFactory
from .config import TrainConfig from .config import TrainConfig
from ..utils import logger from ..utils import logger
from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
from ..utils.develop import deprecated from ..utils.develop import deprecated
from ..callbacks import Callback, Callbacks, MaintainStepCounter from ..callbacks import Callback, Callbacks
from ..callbacks.monitor import Monitors, TrainingMonitor from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils import get_global_step_value from ..tfutils import get_global_step_value, get_global_step_var
from ..tfutils.model_utils import describe_trainable_vars from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sesscreate import ReuseSessionCreator from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.sessinit import JustCurrentSession from ..tfutils.sessinit import JustCurrentSession
from ..graph_builder.predictor_factory import PredictorFactory
__all__ = ['Trainer', 'StopTraining'] __all__ = ['Trainer', 'StopTraining']
...@@ -29,6 +31,34 @@ class StopTraining(BaseException): ...@@ -29,6 +31,34 @@ class StopTraining(BaseException):
pass pass
class MaintainStepCounter(Callback):
"""
It maintains the global step in the graph, making sure it's increased by one.
This callback is always enabled by the trainer, and you wouldn't need to use it.
"""
def _setup_graph(self):
# ensure it exists
gs_var = get_global_step_var()
with tf.name_scope(None):
with tf.device(gs_var.device):
self.gs_incr_op = tf.assign_add(
gs_var, 1,
name=GLOBAL_STEP_INCR_OP_NAME).op
self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)
def _before_train(self):
if self.global_step != 0:
logger.info("Start training with global_step={}".format(self.global_step))
def _before_run(self, _):
# always increase global_step when hooked_sess.run is called
return self._fetches
def _after_run(self, _, __):
# Keep python-side global_step in agreement with TF-side
self.trainer._global_step += 1
class Trainer(object): class Trainer(object):
""" Base class for a trainer. """ Base class for a trainer.
...@@ -38,7 +68,7 @@ class Trainer(object): ...@@ -38,7 +68,7 @@ class Trainer(object):
sess (tf.Session): the current session in use. sess (tf.Session): the current session in use.
hooked_sess (tf.MonitoredSession): the session with hooks. hooked_sess (tf.MonitoredSession): the session with hooks.
monitors (Monitors): the monitors. Callbacks can use it for logging. monitors (Monitors): the monitors. Callbacks can use it for logging.
local_step (int): the number of steps that have finished in the current epoch. local_step (int): the number of (tensorpack) steps that have finished in the current epoch.
""" """
# step attr only available after before_train? # step attr only available after before_train?
...@@ -58,6 +88,7 @@ class Trainer(object): ...@@ -58,6 +88,7 @@ class Trainer(object):
self._callbacks = [] self._callbacks = []
self.monitors = [] self.monitors = []
self._epoch_num = None self._epoch_num = None
self._global_step = 0
self._setup() # subclass will setup the graph and InputSource self._setup() # subclass will setup the graph and InputSource
...@@ -102,24 +133,18 @@ class Trainer(object): ...@@ -102,24 +133,18 @@ class Trainer(object):
def run_step(self): def run_step(self):
""" """
Defines what to do in one iteration, by default is: Defines what to do in one iteration. The default is:
``self.hooked_sess.run(self.train_op)``. ``self.hooked_sess.run(self.train_op)``.
The behavior can be changed by either defining what is ``train_op``, The behavior can be changed by either defining what is ``train_op``,
or overriding this method. or overriding this method.
""" """
assert hasattr(self, 'train_op'), \ if not hasattr(self, 'train_op'):
"Please either set `Trainer.train_op` or provide an implementation " \ raise NotImplementedError(
"of Trainer.run_step()!" "Please either set `Trainer.train_op` or provide an implementation "
"of Trainer.run_step()!")
self.hooked_sess.run(self.train_op) self.hooked_sess.run(self.train_op)
def _setup_input_source(self, input_source):
"""
Setup InputSource on this trainer.
"""
cbs = input_source.setup(self.model.get_inputs_desc())
self.config.callbacks.extend(cbs)
def setup(self): def setup(self):
""" """
Setup the trainer and be ready for the main loop. Setup the trainer and be ready for the main loop.
...@@ -175,25 +200,25 @@ class Trainer(object): ...@@ -175,25 +200,25 @@ class Trainer(object):
@property @property
def global_step(self): def global_step(self):
""" """
The number of steps that have finished or is currently running. The tensorflow global_step, i.e. how many times `hooked_sess.run` has been called.
Note:
1. global_step is incremented **after** each `hooked_sess.run` returns from TF runtime.
2. If you make zero or more than one calls to `hooked_sess.run` in one
`run_step`, local_step and global_step may increment at different speed.
""" """
try: return self._global_step
return self._starting_step + \
self.config.steps_per_epoch * (self.epoch_num - self.config.starting_epoch) + \
self.local_step + 1 # +1: the ongoing step
except AttributeError:
return get_global_step_value()
def main_loop(self): def main_loop(self):
""" """
Run the main training loop. Run the main training loop.
""" """
with self.sess.as_default(): with self.sess.as_default():
self._starting_step = get_global_step_value() self._global_step = get_global_step_value()
try: try:
self._callbacks.before_train() self._callbacks.before_train()
# refresh global step (might have changed by callbacks) TODO ugly # refresh global step (might have changed by callbacks) TODO ugly
self._starting_step = get_global_step_value() self._global_step = get_global_step_value()
for self._epoch_num in range( for self._epoch_num in range(
self.config.starting_epoch, self.config.max_epoch + 1): self.config.starting_epoch, self.config.max_epoch + 1):
logger.info("Start Epoch {} ...".format(self._epoch_num)) logger.info("Start Epoch {} ...".format(self._epoch_num))
...@@ -221,7 +246,7 @@ class Trainer(object): ...@@ -221,7 +246,7 @@ class Trainer(object):
self._callbacks.after_train() self._callbacks.after_train()
self.hooked_sess.close() self.hooked_sess.close()
# Predictor related methods: # Predictor related methods. They actually should not be part of a trainer:
@property @property
def vs_name_for_predictor(self): def vs_name_for_predictor(self):
""" """
......
...@@ -28,7 +28,8 @@ class FeedfreeTrainerBase(Trainer): ...@@ -28,7 +28,8 @@ class FeedfreeTrainerBase(Trainer):
def _setup(self): def _setup(self):
assert isinstance(self._input_source, FeedfreeInput), type(self._input_source) assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
        self._setup_input_source(self._input_source) cbs = self._input_source.setup(self.model.get_inputs_desc())
self.config.callbacks.extend(cbs)
# deprecated # deprecated
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment