Commit 543d9299 authored by Yuxin Wu's avatar Yuxin Wu

use global_step. replace step_per_epoch by steps_per_epoch

parent c8d40e69
......@@ -2,12 +2,13 @@
## Breaking API changes.
tensorpack is still in early development, and API changes can happen.
Usually the backward compatibilty is __preserved for several months__, with a deprecation warning,
The backward compatibility will be __preserved for several months__, with a deprecation warning,
so you won't need to look at here very often.
Here is a list of things that were changed, starting from an early version.
TensorFlow itself also changes API and those are not listed here.
* 2017/01/27. `TrainConfig(step_per_epoch)` was renamed to `steps_per_epoch`.
* 2017/01/25. Argument order of `models.ConcatWith` is changed to follow the API change in
TensorFlow upstream. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/2df3dcf401a99fe61c699ad719e95528872d3abe).
* 2017/01/25. `TrainConfig(callbacks=)` now takes a list of `Callback` instances. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/243e957fe6d62a0cfb5728bd77fb3e005d6603e4)
......
......@@ -141,7 +141,7 @@ def get_config():
dataflow=dataset_train, # the DataFlow instance for training
optimizer=tf.train.AdamOptimizer(lr),
callbacks=[
PeriodicTrigger(ModelSaver(), every_k_steps=100), # save the model after every epoch
ModelSaver(), # save the model after every epoch
InferenceRunner( # run inference(for validation) after every epoch
dataset_test, # the DataFlow instance used for validation
# Calculate both the cost and the error for this DataFlow
......
......@@ -5,7 +5,7 @@
import tensorflow as tf
from abc import ABCMeta
import six
from ..tfutils.common import get_op_or_tensor_by_name
from ..tfutils.common import get_op_or_tensor_by_name, get_global_step_value
__all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback', 'CallbackFactory']
......@@ -15,8 +15,10 @@ class Callback(object):
""" Base class for all callbacks
Attributes:
epoch_num(int): the epoch that have completed the update.
local_step(int): the local step number in the current epoch.
epoch_num(int): the current epoch num, starting from 1.
local_step(int): the current local step number (1-based) in the current epoch,
which is also the number of steps that have finished.
global_step(int): the number of global steps that have finished.
trainer(Trainer): the trainer.
graph(tf.Graph): the graph.
......@@ -33,6 +35,7 @@ class Callback(object):
Args:
trainer(Trainer): the trainer which calls the callback
"""
self._steps_per_epoch = trainer.config.steps_per_epoch
self.trainer = trainer
self.graph = tf.get_default_graph()
with tf.name_scope(type(self).__name__):
......@@ -45,6 +48,7 @@ class Callback(object):
"""
Called right before the first iteration.
"""
self._starting_step = get_global_step_value()
self._before_train()
def _before_train(self):
......@@ -111,7 +115,14 @@ class Callback(object):
@property
def local_step(self):
return self.trainer.local_step
# inside trainer, we're still in the 'local_step' loop, so the number is off by 1
return self.trainer.local_step + 1
@property
def global_step(self):
return self._starting_step + \
self._steps_per_epoch * (self.epoch_num - 1) + \
self.local_step
def __str__(self):
return type(self).__name__
......
......@@ -56,7 +56,7 @@ class MaintainStepCounter(Callback):
gs_var, 1,
name=GLOBAL_STEP_INCR_OP_NAME)
tf.mod(
self.gs_incr_var, self.trainer.config.step_per_epoch,
self.gs_incr_var, self.trainer.config.steps_per_epoch,
name=LOCAL_STEP_OP_NAME)
def _before_train(self):
......@@ -71,12 +71,12 @@ class MaintainStepCounter(Callback):
class ProgressBar(Callback):
""" A progress bar based on tqdm. Enabled by default. """
def _before_train(self):
self._total = self.trainer.config.step_per_epoch
self._total = self.trainer.config.steps_per_epoch
self._tqdm_args = get_tqdm_kwargs(leave=True)
def _trigger_step(self, *args):
if self.local_step == 0:
if self.local_step == 1:
self._bar = tqdm.trange(self._total, **self._tqdm_args)
self._bar.update()
if self.local_step == self._total - 1:
if self.local_step == self._total:
self._bar.close()
......@@ -76,5 +76,5 @@ class PeriodicTrigger(ProxyCallback):
def _trigger_epoch(self, *args):
if self._epoch_k is None:
return
if self.local_step % self._epoch_k == 0:
if self.epoch_num % self._epoch_k == 0:
self.cb.trigger()
......@@ -163,7 +163,7 @@ class Trainer(object):
self.config.starting_epoch, self.config.max_epoch + 1):
logger.info("Start Epoch {} ...".format(self.epoch_num))
start_time = time.time()
for self.local_step in range(self.config.step_per_epoch):
for self.local_step in range(self.config.steps_per_epoch):
if self.coord.should_stop():
return
fetch_data = self.run_step() # implemented by subclass
......
......@@ -28,7 +28,7 @@ class TrainConfig(object):
callbacks=None, extra_callbacks=None,
session_config=get_default_sess_config(),
session_init=None,
starting_epoch=1, step_per_epoch=None, max_epoch=99999,
starting_epoch=1, steps_per_epoch=None, max_epoch=99999,
nr_tower=1, tower=None, predict_tower=[0],
**kwargs):
"""
......@@ -48,7 +48,7 @@ class TrainConfig(object):
session_config (tf.ConfigProto): the config used to instantiate the session.
session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
starting_epoch (int): The index of the first epoch.
step_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch.
steps_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch.
Defaults to the input data size.
max_epoch (int): maximum number of epoch to run training.
nr_tower (int): number of training towers.
......@@ -103,21 +103,26 @@ class TrainConfig(object):
self.session_init = session_init
assert_type(self.session_init, SessionInit)
self.step_per_epoch = step_per_epoch
if self.step_per_epoch is None:
if steps_per_epoch is None:
steps_per_epoch = kwargs.pop('step_per_epoch', None)
if steps_per_epoch is not None:
# TODO deprecate @Mar.27
logger.warn("[Deprecated] Use steps_per_epoch instead of step_per_epoch!")
if steps_per_epoch is None:
try:
if dataflow is not None:
self.step_per_epoch = self.dataflow.size()
steps_per_epoch = self.dataflow.size()
else:
self.step_per_epoch = self.data.size()
steps_per_epoch = self.data.size()
except NotImplementedError:
logger.exception("You must set `step_per_epoch` if dataset.size() is not implemented.")
logger.exception("You must set `steps_per_epoch` if dataset.size() is not implemented.")
else:
self.step_per_epoch = int(self.step_per_epoch)
steps_per_epoch = int(steps_per_epoch)
self.steps_per_epoch = steps_per_epoch
self.starting_epoch = int(starting_epoch)
self.max_epoch = int(max_epoch)
assert self.step_per_epoch >= 0 and self.max_epoch > 0
assert self.steps_per_epoch >= 0 and self.max_epoch > 0
self.nr_tower = nr_tower
if tower is not None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment