Commit a6aa8bce authored by Yuxin Wu

Define global_step to be the number of hooked_sess.run calls

parent b747c068
@@ -83,7 +83,6 @@ if __name__ == '__main__':
         max_epoch=200,
         session_init=SaverRestore(args.load) if args.load else None
     )
-    """
-    The original code uses a different schedule, but this seems to work well.
-    """
+    # The original code uses a different schedule, but this seems to work well.
     # Train 1 D after 2 G
     SeparateGANTrainer(config, d_period=3).train()
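A note on the schedule above: with `d_period=3`, the discriminator update runs on one step out of every three and the generator on the remaining two, which is what the "1 D after 2 G" comment describes. A minimal sketch of that dispatch (hypothetical helper, assuming `SeparateGANTrainer` runs D on every `d_period`-th step):

```python
# Hypothetical illustration of the d_period=3 schedule:
# D is updated on one step out of every three, G on the other two.
def op_for_step(step, d_period=3):
    # assumption: D runs when step % d_period == 0, G otherwise
    return 'D' if step % d_period == 0 else 'G'

print([op_for_step(s) for s in range(9)])
# -> ['D', 'G', 'G', 'D', 'G', 'G', 'D', 'G', 'G']
```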
@@ -19,9 +19,9 @@ class Callback(object):
     for more detailed explanation of the callback methods.

     Attributes:
-        epoch_num(int): the number of the current epoch.
-        global_step(int): the number of global steps that have finished or is currently running.
-        local_step(int): the local steps within the current epoch.
+        epoch_num(int): trainer.epoch_num
+        global_step(int): trainer.global_step
+        local_step(int): trainer.local_step
         trainer(Trainer): the trainer.
         graph(tf.Graph): the graph.
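The rewritten attribute docs simply point at the trainer, which suggests plain delegation. A sketch of what that presumably looks like (the property definitions themselves are not part of this hunk):

```python
# Presumed delegation behind the new docstring: the callback's step/epoch
# attributes forward to the trainer that owns it.
class Callback(object):
    @property
    def epoch_num(self):
        return self.trainer.epoch_num

    @property
    def global_step(self):
        return self.trainer.global_step

    @property
    def local_step(self):
        return self.trainer.local_step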
@@ -10,14 +10,11 @@ import tqdm
 from ..utils import logger
 from ..utils.utils import get_tqdm_kwargs
-from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
 from ..tfutils.common import (
-    get_op_tensor_name, get_global_step_var,
-    get_global_step_value, get_op_or_tensor_by_name)
+    get_op_tensor_name, get_op_or_tensor_by_name)
 from .base import Callback

-__all__ = ['StepTensorPrinter', 'MaintainStepCounter',
-           'ProgressBar']
+__all__ = ['StepTensorPrinter', 'ProgressBar']

 class StepTensorPrinter(Callback):
@@ -47,39 +44,6 @@ class StepTensorPrinter(Callback):
             logger.info("{}: {}".format(n, v))

-class MaintainStepCounter(Callback):
-    """
-    It maintains the global step in the graph, making sure it's increased by one in every `run_step` call.
-    This callback is always enabled by the trainer, and you wouldn't need to use it.
-    """
-    def _setup_graph(self):
-        # ensure it exists
-        gs_var = get_global_step_var()
-        with tf.name_scope(None):
-            with tf.device(gs_var.device):
-                self.gs_incr_op = tf.assign_add(
-                    gs_var, 1,
-                    name=GLOBAL_STEP_INCR_OP_NAME).op
-        # tf.mod(
-        #     self.gs_incr_var, self.trainer.config.steps_per_epoch,
-        #     name=LOCAL_STEP_OP_NAME)
-        self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)
-
-    def _before_train(self):
-        gs_val = get_global_step_value()
-        if gs_val != 0:
-            logger.info("Start training with global_step={}".format(gs_val))
-        self._last_updated = self.local_step
-
-    def _before_run(self, _):
-        # increase global_step, when trainer.local_step changed
-        if self.local_step != self._last_updated:
-            self._last_updated = self.local_step
-            return self._fetches
-        else:
-            return None
-
 class ProgressBar(Callback):
     """ A progress bar based on tqdm. Enabled by default. """
@@ -8,17 +8,19 @@ from six.moves import range
 import tensorflow as tf

-from ..graph_builder.predictor_factory import PredictorFactory
 from .config import TrainConfig
 from ..utils import logger
+from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
 from ..utils.develop import deprecated
-from ..callbacks import Callback, Callbacks, MaintainStepCounter
+from ..callbacks import Callback, Callbacks
 from ..callbacks.monitor import Monitors, TrainingMonitor
-from ..tfutils import get_global_step_value
+from ..tfutils import get_global_step_value, get_global_step_var
 from ..tfutils.model_utils import describe_trainable_vars
 from ..tfutils.sesscreate import ReuseSessionCreator
 from ..tfutils.sessinit import JustCurrentSession
+from ..graph_builder.predictor_factory import PredictorFactory

 __all__ = ['Trainer', 'StopTraining']

@@ -29,6 +31,34 @@ class StopTraining(BaseException):
     pass

+class MaintainStepCounter(Callback):
+    """
+    It maintains the global step in the graph, making sure it's increased by one.
+    This callback is always enabled by the trainer, and you wouldn't need to use it.
+    """
+    def _setup_graph(self):
+        # ensure it exists
+        gs_var = get_global_step_var()
+        with tf.name_scope(None):
+            with tf.device(gs_var.device):
+                self.gs_incr_op = tf.assign_add(
+                    gs_var, 1,
+                    name=GLOBAL_STEP_INCR_OP_NAME).op
+        self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)
+
+    def _before_train(self):
+        if self.global_step != 0:
+            logger.info("Start training with global_step={}".format(self.global_step))
+
+    def _before_run(self, _):
+        # always increase global_step when hooked_sess.run is called
+        return self._fetches
+
+    def _after_run(self, _, __):
+        # Keep python-side global_step in agreement with TF-side
+        self.trainer._global_step += 1
+
 class Trainer(object):
     """ Base class for a trainer.
@@ -38,7 +68,7 @@ class Trainer(object):
         sess (tf.Session): the current session in use.
         hooked_sess (tf.MonitoredSession): the session with hooks.
         monitors (Monitors): the monitors. Callbacks can use it for logging.
-        local_step (int): the number of steps that have finished in the current epoch.
+        local_step (int): the number of (tensorpack) steps that have finished in the current epoch.
     """
     # step attr only available after before_train?
@@ -58,6 +88,7 @@ class Trainer(object):
         self._callbacks = []
         self.monitors = []
         self._epoch_num = None
+        self._global_step = 0

         self._setup()   # subclass will setup the graph and InputSource
@@ -102,24 +133,18 @@ class Trainer(object):
     def run_step(self):
         """
-        Defines what to do in one iteration, by default is:
+        Defines what to do in one iteration. The default is:
         ``self.hooked_sess.run(self.train_op)``.

         The behavior can be changed by either defining what is ``train_op``,
         or overriding this method.
         """
-        assert hasattr(self, 'train_op'), \
-            "Please either set `Trainer.train_op` or provide an implementation " \
-            "of Trainer.run_step()!"
+        if not hasattr(self, 'train_op'):
+            raise NotImplementedError(
+                "Please either set `Trainer.train_op` or provide an implementation "
+                "of Trainer.run_step()!")
         self.hooked_sess.run(self.train_op)
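To make the contract above concrete: a subclass only has to set `train_op`, and the default `run_step` executes it through the hooked session. A hypothetical minimal subclass (ignoring config/input plumbing), just to show the shape:

```python
# Hypothetical sketch: defining self.train_op is enough for the default
# run_step(), which then does self.hooked_sess.run(self.train_op).
class ToyTrainer(Trainer):
    def _setup(self):
        x = tf.get_variable('x', initializer=5.0)
        loss = tf.square(x - 3.0)
        self.train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
```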
-    def _setup_input_source(self, input_source):
-        """
-        Setup InputSource on this trainer.
-        """
-        cbs = input_source.setup(self.model.get_inputs_desc())
-        self.config.callbacks.extend(cbs)
-
     def setup(self):
         """
         Setup the trainer and be ready for the main loop.
@@ -175,25 +200,25 @@ class Trainer(object):
     @property
     def global_step(self):
         """
-        The number of steps that have finished or is currently running.
+        The tensorflow global_step, i.e. how many times `hooked_sess.run` has been called.
+
+        Note:
+            1. global_step is incremented **after** each `hooked_sess.run` returns from TF runtime.
+            2. If you make zero or more than one calls to `hooked_sess.run` in one
+               `run_step`, local_step and global_step may increment at different speed.
         """
-        try:
-            return self._starting_step + \
-                self.config.steps_per_epoch * (self.epoch_num - self.config.starting_epoch) + \
-                self.local_step + 1  # +1: the ongoing step
-        except AttributeError:
-            return get_global_step_value()
+        return self._global_step

     def main_loop(self):
         """
         Run the main training loop.
         """
         with self.sess.as_default():
-            self._starting_step = get_global_step_value()
+            self._global_step = get_global_step_value()
             try:
                 self._callbacks.before_train()
                 # refresh global step (might have changed by callbacks) TODO ugly
-                self._starting_step = get_global_step_value()
+                self._global_step = get_global_step_value()
                 for self._epoch_num in range(
                         self.config.starting_epoch, self.config.max_epoch + 1):
                     logger.info("Start Epoch {} ...".format(self._epoch_num))
@@ -221,7 +246,7 @@ class Trainer(object):
             self._callbacks.after_train()
             self.hooked_sess.close()

-    # Predictor related methods:
+    # Predictor related methods. They actually should not be part of a trainer:
     @property
     def vs_name_for_predictor(self):
         """
@@ -28,7 +28,8 @@ class FeedfreeTrainerBase(Trainer):
     def _setup(self):
         assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
-        self._setup_input_source(self._input_source)
+        cbs = self._input_source.setup(self.model.get_inputs_desc())
+        self.config.callbacks.extend(cbs)

 # deprecated