Commit 589a8a35 authored by Yuxin Wu

make progressbar a callback

parent a59e46cd
@@ -10,7 +10,6 @@ import imp
from tensorpack import TowerContext, logger, ModelFromMetaGraph
from tensorpack.tfutils import sessinit, varmanip
from tensorpack.utils.naming import EXTRA_SAVE_VARS_KEY
parser = argparse.ArgumentParser()
parser.add_argument('--config', help='config file')
@@ -44,7 +43,7 @@ with tf.Graph().as_default() as G:
varmanip.dump_session_params(args.output)
else:
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tf.get_collection(EXTRA_SAVE_VARS_KEY))
var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
var_dict = {}
for v in var:
name = varmanip.get_savename_from_varname(v.name)
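The script hunk above also replaces the tensorpack-specific EXTRA_SAVE_VARS_KEY collection with TensorFlow's built-in tf.GraphKeys.MODEL_VARIABLES when gathering variables to dump. Below is a small standalone illustration of why the two collections are combined, assuming (as slim/contrib layers do) that non-trainable model state such as batch-norm moving statistics is registered under MODEL_VARIABLES; this is a sketch, not code from the repository:

import tensorflow as tf

# A trainable variable is tracked in TRAINABLE_VARIABLES by default.
w = tf.get_variable('w', shape=[3])
# Non-trainable model state must be added to MODEL_VARIABLES explicitly.
moving_mean = tf.get_variable(
    'moving_mean', shape=[3], trainable=False,
    collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.MODEL_VARIABLES])

to_save = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
to_save.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
print([v.name for v in to_save])   # ['w:0', 'moving_mean:0']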
@@ -15,9 +15,10 @@ class Callback(object):
""" Base class for all callbacks
Attributes:
epoch_num(int): the number of epochs that have completed the update
trainer(Trainer): the trainer
graph(tf.Graph): the graph
epoch_num(int): the epoch that has completed the update.
step_num(int): the step number in the current epoch.
trainer(Trainer): the trainer.
graph(tf.Graph): the graph.
Note:
These attributes are available only after (and including)
@@ -34,7 +35,6 @@ class Callback(object):
"""
self.trainer = trainer
self.graph = tf.get_default_graph()
self.epoch_num = self.trainer.config.starting_epoch - 1
with tf.name_scope(type(self).__name__):
self._setup_graph()
@@ -91,7 +91,6 @@ class Callback(object):
"""
Triggered after every epoch.
"""
self.epoch_num += 1
self._trigger_epoch()
def _trigger_epoch(self):
@@ -106,6 +105,14 @@ class Callback(object):
def _after_train(self):
pass
@property
def epoch_num(self):
return self.trainer.epoch_num
@property
def step_num(self):
return self.trainer.step_num
def __str__(self):
return type(self).__name__
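Since epoch_num and step_num are now read-only properties that delegate to the trainer, a user-defined callback can read its current position in training without keeping counters of its own. A minimal sketch under that assumption (PositionLogger and the import path are illustrative, not part of this commit):

from tensorpack.callbacks.base import Callback   # module path assumed from this diff

class PositionLogger(Callback):
    def _trigger_step(self, *args):
        # Both properties delegate to self.trainer, so they are always current.
        if self.step_num % 100 == 0:
            print("epoch {}, step {}".format(self.epoch_num, self.step_num))

    def _trigger_epoch(self):
        print("finished epoch {}".format(self.epoch_num))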
@@ -128,12 +135,15 @@ class ProxyCallback(Callback):
def _setup_graph(self):
self.cb.setup_graph(self.trainer)
def _after_train(self):
self.cb.after_train()
def _trigger_epoch(self):
self.cb.trigger_epoch()
def _trigger_step(self, *args):
self.cb.trigger_step(*args)
def _after_train(self):
self.cb.after_train()
def __str__(self):
return "Proxy-" + str(self.cb)
@@ -112,7 +112,8 @@ class StatHolder(object):
class StatPrinter(Callback):
"""
A callback to control what stats to print. Print everything by default.
A callback to control what stats to print. Enabled by default; it prints
everything in trainer.stat_holder.
"""
def __init__(self, print_tag=None):
@@ -8,13 +8,14 @@
import tensorflow as tf
import re
from six.moves import zip
import tqdm
from ..utils import logger
from ..utils import logger, get_tqdm_kwargs
from ..utils.naming import MOVING_SUMMARY_VARS_KEY
from ..tfutils.common import get_op_tensor_name, get_global_step_var
from .base import Callback
__all__ = ['StepStatPrinter', 'SummaryMovingAverage']
__all__ = ['StepStatPrinter', 'SummaryMovingAverage', 'ProgressBar']
class StepStatPrinter(Callback):
@@ -38,7 +39,7 @@ class StepStatPrinter(Callback):
class SummaryMovingAverage(Callback):
""" Maintain the moving average of the tensors
in every step, and summarize them.
in every step, and summarize them. Enabled by default.
"""
def __init__(self, collection=MOVING_SUMMARY_VARS_KEY, decay=0.95):
"""
@@ -65,3 +66,17 @@
def _extra_fetches(self):
return [self.ema_op]
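For context, the moving-average-then-summarize pattern that SummaryMovingAverage maintains corresponds roughly to the plain-TensorFlow sketch below; this is an illustration with a made-up fake_cost tensor, not the callback's actual implementation:

import tensorflow as tf

fake_cost = tf.reduce_mean(tf.random_normal([32]), name='fake_cost')
ema = tf.train.ExponentialMovingAverage(decay=0.95)   # same default decay as above
ema_op = ema.apply([fake_cost])      # meant to be fetched alongside every train step
tf.summary.scalar('fake_cost/ema', ema.average(fake_cost))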
class ProgressBar(Callback):
""" A progress bar based on tqdm. Enabled by default. """
def _before_train(self):
self._total = self.trainer.config.step_per_epoch
self._tqdm_args = get_tqdm_kwargs(leave=True)
def _trigger_step(self, *args):
if self.step_num == 0:
self._bar = tqdm.trange(self._total, **self._tqdm_args)
self._bar.update()
if self.step_num == self._total - 1:
self._bar.__exit__()
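The new callback creates a tqdm bar lazily on the first step of each epoch, advances it once per step, and closes it on the last step. A rough standalone equivalent of that lifecycle, using plain tqdm defaults instead of tensorpack's get_tqdm_kwargs helper (illustration only):

import tqdm

step_per_epoch = 100              # stands in for trainer.config.step_per_epoch
bar = None
for step_num in range(step_per_epoch):
    if step_num == 0:
        bar = tqdm.trange(step_per_epoch, leave=True)
    bar.update()                  # one tick per training step
    if step_num == step_per_epoch - 1:
        bar.close()               # what the __exit__() call above amounts to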
@@ -7,11 +7,10 @@ import re
import weakref
import six
from six.moves import range
import tqdm
import tensorflow as tf
from .config import TrainConfig
from ..utils import logger, get_tqdm_kwargs
from ..utils import logger
from ..utils.timer import timed_operation
from ..callbacks import StatHolder
from ..tfutils import get_global_step, get_global_step_var
@@ -33,14 +32,18 @@ class Trainer(object):
""" Base class for a trainer.
Attributes:
stat_holder (StatHolder)
summary_writer (tf.summary.FileWriter)
summary_op (tf.Operation): an Op which outputs all summaries.
config (TrainConfig): the config used in this trainer.
model (ModelDesc)
sess (tf.Session): the current session in use.
coord (tf.train.Coordinator)
stat_holder (StatHolder)
summary_writer (tf.summary.FileWriter)
summary_op (tf.Operation): an Op which outputs all summaries.
extra_fetches (list): list of tensors/ops fetched by :meth:`run_step`.
epoch_num (int): the current epoch number.
step_num (int): the current step number (in an epoch).
"""
def __init__(self, config):
@@ -54,6 +57,9 @@ class Trainer(object):
self.sess = tf.Session(config=self.config.session_config)
self.coord = tf.train.Coordinator()
self.epoch_num = self.config.starting_epoch
self.step_num = 0
def train(self):
""" Start training """
self.setup()
@@ -165,15 +171,13 @@
try:
callbacks.before_train()
logger.info("Start training with global_step={}".format(get_global_step()))
for epoch_num in range(
for self.epoch_num in range(
self.config.starting_epoch, self.config.max_epoch + 1):
with timed_operation(
'Epoch {} (global_step {})'.format(
epoch_num, get_global_step() + self.config.step_per_epoch),
self.epoch_num, get_global_step() + self.config.step_per_epoch),
log_start=True):
for step in tqdm.trange(
self.config.step_per_epoch,
**get_tqdm_kwargs(leave=True)):
for self.step_num in range(self.config.step_per_epoch):
if self.coord.should_stop():
return
fetch_data = self.run_step() # implemented by subclass
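The rewritten loops bind the loop variable directly to an instance attribute (for self.epoch_num in ..., for self.step_num in ...), which is what keeps the new Callback.epoch_num and Callback.step_num properties current without extra bookkeeping. A tiny self-contained illustration of that Python idiom (not repository code):

class Looper(object):
    def run(self, n):
        for self.i in range(n):   # rebinds self.i on every iteration
            pass                  # any object holding a reference to `self` sees the live value
        return self.i             # stays bound to the last value afterwards

assert Looper().run(5) == 4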
@@ -4,7 +4,9 @@
import tensorflow as tf
from ..callbacks import Callbacks, SummaryMovingAverage, StatPrinter
from ..callbacks import (
Callbacks, SummaryMovingAverage,
StatPrinter, ProgressBar)
from ..dataflow.base import DataFlow
from ..models import ModelDesc
from ..utils import logger
@@ -38,8 +40,8 @@ class TrainConfig(object):
callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument
is only used to provide the defaults. The defaults are
``[SummaryMovingAverage(), StatPrinter()]``. The list of
callbacks that will be used in the end is ``callbacks + extra_callbacks``.
``[SummaryMovingAverage(), ProgressBar(), StatPrinter()]``. The list of
callbacks that will be used in the end is ``callbacks + extra_callbacks``.
session_config (tf.ConfigProto): the config used to instantiate the session.
session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
starting_epoch (int): The index of the first epoch.
@@ -80,7 +82,7 @@ class TrainConfig(object):
callbacks = callbacks.cbs[:-1] # the last one is StatPrinter()
assert_type(callbacks, list)
if extra_callbacks is None:
extra_callbacks = [SummaryMovingAverage(), StatPrinter()]
extra_callbacks = [SummaryMovingAverage(), ProgressBar(), StatPrinter()]
self.callbacks = callbacks + extra_callbacks
assert_type(self.callbacks, list)
self.callbacks = Callbacks(self.callbacks)
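With ProgressBar now part of the default extra_callbacks, keeping, reordering, or dropping the per-step bar is just a matter of passing the list explicitly. A sketch of the two common cases (import path inferred from the relative import in the hunk above; the other TrainConfig arguments are omitted):

from tensorpack.callbacks import SummaryMovingAverage, ProgressBar, StatPrinter

# Equivalent to leaving extra_callbacks=None after this commit:
default_extra = [SummaryMovingAverage(), ProgressBar(), StatPrinter()]

# Passing this as TrainConfig(..., extra_callbacks=no_bar_extra) turns the
# per-step progress bar off without touching the trainer code:
no_bar_extra = [SummaryMovingAverage(), StatPrinter()]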