Commit 589a8a35 authored by Yuxin Wu's avatar Yuxin Wu

make progressbar a callback

parent a59e46cd
...@@ -10,7 +10,6 @@ import imp ...@@ -10,7 +10,6 @@ import imp
from tensorpack import TowerContext, logger, ModelFromMetaGraph from tensorpack import TowerContext, logger, ModelFromMetaGraph
from tensorpack.tfutils import sessinit, varmanip from tensorpack.tfutils import sessinit, varmanip
from tensorpack.utils.naming import EXTRA_SAVE_VARS_KEY
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--config', help='config file') parser.add_argument('--config', help='config file')
...@@ -44,7 +43,7 @@ with tf.Graph().as_default() as G: ...@@ -44,7 +43,7 @@ with tf.Graph().as_default() as G:
varmanip.dump_session_params(args.output) varmanip.dump_session_params(args.output)
else: else:
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tf.get_collection(EXTRA_SAVE_VARS_KEY)) var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
var_dict = {} var_dict = {}
for v in var: for v in var:
name = varmanip.get_savename_from_varname(v.name) name = varmanip.get_savename_from_varname(v.name)
......
...@@ -15,9 +15,10 @@ class Callback(object): ...@@ -15,9 +15,10 @@ class Callback(object):
""" Base class for all callbacks """ Base class for all callbacks
Attributes: Attributes:
epoch_num(int): the number of epochs that have completed the update epoch_num(int): the epoch that have completed the update.
trainer(Trainer): the trainer step_num(int): the step number in the current epoch.
graph(tf.Graph): the graph trainer(Trainer): the trainer.
graph(tf.Graph): the graph.
Note: Note:
These attributes are available only after (and including) These attributes are available only after (and including)
...@@ -34,7 +35,6 @@ class Callback(object): ...@@ -34,7 +35,6 @@ class Callback(object):
""" """
self.trainer = trainer self.trainer = trainer
self.graph = tf.get_default_graph() self.graph = tf.get_default_graph()
self.epoch_num = self.trainer.config.starting_epoch - 1
with tf.name_scope(type(self).__name__): with tf.name_scope(type(self).__name__):
self._setup_graph() self._setup_graph()
...@@ -91,7 +91,6 @@ class Callback(object): ...@@ -91,7 +91,6 @@ class Callback(object):
""" """
Triggered after every epoch. Triggered after every epoch.
""" """
self.epoch_num += 1
self._trigger_epoch() self._trigger_epoch()
def _trigger_epoch(self): def _trigger_epoch(self):
...@@ -106,6 +105,14 @@ class Callback(object): ...@@ -106,6 +105,14 @@ class Callback(object):
def _after_train(self): def _after_train(self):
pass pass
@property
def epoch_num(self):
return self.trainer.epoch_num
@property
def step_num(self):
return self.trainer.step_num
def __str__(self): def __str__(self):
return type(self).__name__ return type(self).__name__
...@@ -128,12 +135,15 @@ class ProxyCallback(Callback): ...@@ -128,12 +135,15 @@ class ProxyCallback(Callback):
def _setup_graph(self): def _setup_graph(self):
self.cb.setup_graph(self.trainer) self.cb.setup_graph(self.trainer)
def _after_train(self):
self.cb.after_train()
def _trigger_epoch(self): def _trigger_epoch(self):
self.cb.trigger_epoch() self.cb.trigger_epoch()
def _trigger_step(self, *args):
self.cb.trigger_step(*args)
def _after_train(self):
self.cb.after_train()
def __str__(self): def __str__(self):
return "Proxy-" + str(self.cb) return "Proxy-" + str(self.cb)
......
...@@ -112,7 +112,8 @@ class StatHolder(object): ...@@ -112,7 +112,8 @@ class StatHolder(object):
class StatPrinter(Callback): class StatPrinter(Callback):
""" """
A callback to control what stats to print. Print everything by default. A callback to control what stats to print. Enable by default to print
everything in trainer.stat_holder.
""" """
def __init__(self, print_tag=None): def __init__(self, print_tag=None):
......
...@@ -8,13 +8,14 @@ ...@@ -8,13 +8,14 @@
import tensorflow as tf import tensorflow as tf
import re import re
from six.moves import zip from six.moves import zip
import tqdm
from ..utils import logger from ..utils import logger, get_tqdm_kwargs
from ..utils.naming import MOVING_SUMMARY_VARS_KEY from ..utils.naming import MOVING_SUMMARY_VARS_KEY
from ..tfutils.common import get_op_tensor_name, get_global_step_var from ..tfutils.common import get_op_tensor_name, get_global_step_var
from .base import Callback from .base import Callback
__all__ = ['StepStatPrinter', 'SummaryMovingAverage'] __all__ = ['StepStatPrinter', 'SummaryMovingAverage', 'ProgressBar']
class StepStatPrinter(Callback): class StepStatPrinter(Callback):
...@@ -38,7 +39,7 @@ class StepStatPrinter(Callback): ...@@ -38,7 +39,7 @@ class StepStatPrinter(Callback):
class SummaryMovingAverage(Callback): class SummaryMovingAverage(Callback):
""" Maintain the moving average of the tensors """ Maintain the moving average of the tensors
in every step, and summarize them. in every step, and summarize them. Enabled by default.
""" """
def __init__(self, collection=MOVING_SUMMARY_VARS_KEY, decay=0.95): def __init__(self, collection=MOVING_SUMMARY_VARS_KEY, decay=0.95):
""" """
...@@ -65,3 +66,17 @@ class SummaryMovingAverage(Callback): ...@@ -65,3 +66,17 @@ class SummaryMovingAverage(Callback):
def _extra_fetches(self): def _extra_fetches(self):
return [self.ema_op] return [self.ema_op]
class ProgressBar(Callback):
""" A progress bar based on tqdm. Enabled by default. """
def _before_train(self):
self._total = self.trainer.config.step_per_epoch
self._tqdm_args = get_tqdm_kwargs(leave=True)
def _trigger_step(self, *args):
if self.step_num == 0:
self._bar = tqdm.trange(self._total, **self._tqdm_args)
self._bar.update()
if self.step_num == self._total - 1:
self._bar.__exit__()
...@@ -7,11 +7,10 @@ import re ...@@ -7,11 +7,10 @@ import re
import weakref import weakref
import six import six
from six.moves import range from six.moves import range
import tqdm
import tensorflow as tf import tensorflow as tf
from .config import TrainConfig from .config import TrainConfig
from ..utils import logger, get_tqdm_kwargs from ..utils import logger
from ..utils.timer import timed_operation from ..utils.timer import timed_operation
from ..callbacks import StatHolder from ..callbacks import StatHolder
from ..tfutils import get_global_step, get_global_step_var from ..tfutils import get_global_step, get_global_step_var
...@@ -33,14 +32,18 @@ class Trainer(object): ...@@ -33,14 +32,18 @@ class Trainer(object):
""" Base class for a trainer. """ Base class for a trainer.
Attributes: Attributes:
stat_holder (StatHolder)
summary_writer (tf.summary.FileWriter)
summary_op (tf.Operation): an Op which outputs all summaries.
config (TrainConfig): the config used in this trainer. config (TrainConfig): the config used in this trainer.
model (ModelDesc) model (ModelDesc)
sess (tf.Session): the current session in use. sess (tf.Session): the current session in use.
coord (tf.train.Coordinator) coord (tf.train.Coordinator)
stat_holder (StatHolder)
summary_writer (tf.summary.FileWriter)
summary_op (tf.Operation): an Op which outputs all summaries.
extra_fetches (list): list of tensors/ops to fetch by :meth:`run_step`. extra_fetches (list): list of tensors/ops to fetch by :meth:`run_step`.
epoch_num (int): the current epoch number.
step_num (int): the current step number (in an epoch).
""" """
def __init__(self, config): def __init__(self, config):
...@@ -54,6 +57,9 @@ class Trainer(object): ...@@ -54,6 +57,9 @@ class Trainer(object):
self.sess = tf.Session(config=self.config.session_config) self.sess = tf.Session(config=self.config.session_config)
self.coord = tf.train.Coordinator() self.coord = tf.train.Coordinator()
self.epoch_num = self.config.starting_epoch
self.step_num = 0
def train(self): def train(self):
""" Start training """ """ Start training """
self.setup() self.setup()
...@@ -165,15 +171,13 @@ class Trainer(object): ...@@ -165,15 +171,13 @@ class Trainer(object):
try: try:
callbacks.before_train() callbacks.before_train()
logger.info("Start training with global_step={}".format(get_global_step())) logger.info("Start training with global_step={}".format(get_global_step()))
for epoch_num in range( for self.epoch_num in range(
self.config.starting_epoch, self.config.max_epoch + 1): self.config.starting_epoch, self.config.max_epoch + 1):
with timed_operation( with timed_operation(
'Epoch {} (global_step {})'.format( 'Epoch {} (global_step {})'.format(
epoch_num, get_global_step() + self.config.step_per_epoch), self.epoch_num, get_global_step() + self.config.step_per_epoch),
log_start=True): log_start=True):
for step in tqdm.trange( for self.step_num in range(self.config.step_per_epoch):
self.config.step_per_epoch,
**get_tqdm_kwargs(leave=True)):
if self.coord.should_stop(): if self.coord.should_stop():
return return
fetch_data = self.run_step() # implemented by subclass fetch_data = self.run_step() # implemented by subclass
......
...@@ -4,7 +4,9 @@ ...@@ -4,7 +4,9 @@
import tensorflow as tf import tensorflow as tf
from ..callbacks import Callbacks, SummaryMovingAverage, StatPrinter from ..callbacks import (
Callbacks, SummaryMovingAverage,
StatPrinter, ProgressBar)
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
from ..models import ModelDesc from ..models import ModelDesc
from ..utils import logger from ..utils import logger
...@@ -38,8 +40,8 @@ class TrainConfig(object): ...@@ -38,8 +40,8 @@ class TrainConfig(object):
callbacks (list): a list of :class:`Callback` to perform during training. callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument extra_callbacks (list): the same as ``callbacks``. This argument
is only used to provide the defaults. The defaults are is only used to provide the defaults. The defaults are
``[SummaryMovingAverage(), StatPrinter()]``. The list of ``[SummaryMovingAverage(), ProgressBar(), StatPrinter()]``. The list of
callbacks that will be used in the end is ``callbacks + extra_callbacks``. callbacks that will be used in the end are ``callbacks + extra_callbacks``.
session_config (tf.ConfigProto): the config used to instantiate the session. session_config (tf.ConfigProto): the config used to instantiate the session.
session_init (SessionInit): how to initialize variables of a session. Defaults to a new session. session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
starting_epoch (int): The index of the first epoch. starting_epoch (int): The index of the first epoch.
...@@ -80,7 +82,7 @@ class TrainConfig(object): ...@@ -80,7 +82,7 @@ class TrainConfig(object):
callbacks = callbacks.cbs[:-1] # the last one is StatPrinter() callbacks = callbacks.cbs[:-1] # the last one is StatPrinter()
assert_type(callbacks, list) assert_type(callbacks, list)
if extra_callbacks is None: if extra_callbacks is None:
extra_callbacks = [SummaryMovingAverage(), StatPrinter()] extra_callbacks = [SummaryMovingAverage(), ProgressBar(), StatPrinter()]
self.callbacks = callbacks + extra_callbacks self.callbacks = callbacks + extra_callbacks
assert_type(self.callbacks, list) assert_type(self.callbacks, list)
self.callbacks = Callbacks(self.callbacks) self.callbacks = Callbacks(self.callbacks)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment