Commit 543d9299 authored by Yuxin Wu

use global_step. replace step_per_epoch by steps_per_epoch

parent c8d40e69
@@ -2,12 +2,13 @@
 ## Breaking API changes.
 tensorpack is still in early development, and API changes can happen.
-Usually the backward compatibilty is __preserved for several months__, with a deprecation warning,
+The backward compatibilty will be __preserved for several months__, with a deprecation warning,
 so you won't need to look at here very often.
 Here are a list of things that were changed, starting from an early version.
 TensorFlow itself also changes API and those are not listed here.
+* 2017/01/27. `TrainConfig(step_per_epoch)` was renamed to `steps_per_epoch`.
 * 2017/01/25. Argument order of `models.ConcatWith` is changed to follow the API change in
   TensorFlow upstream. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/2df3dcf401a99fe61c699ad719e95528872d3abe).
 * 2017/01/25. `TrainConfig(callbacks=)` now takes a list of `Callback` instances. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/243e957fe6d62a0cfb5728bd77fb3e005d6603e4)
...
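For callers, the rename is mechanical: pass `steps_per_epoch` where `step_per_epoch` was used before. A hypothetical call site is sketched below; `dataset_train`, `lr`, and the callback list are placeholders standing in for a real configuration, not code from this commit:

```python
import tensorflow as tf
from tensorpack import TrainConfig, ModelSaver

# dataset_train and lr are placeholders for a real DataFlow and learning rate.
config = TrainConfig(
    dataflow=dataset_train,
    optimizer=tf.train.AdamOptimizer(lr),
    callbacks=[ModelSaver()],
    steps_per_epoch=1000,    # was: step_per_epoch=1000 before 2017/01/27
)
```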
@@ -141,7 +141,7 @@ def get_config():
         dataflow=dataset_train,  # the DataFlow instance for training
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=[
-            PeriodicTrigger(ModelSaver(), every_k_steps=100),  # save the model after every epoch
+            ModelSaver(),   # save the model after every epoch
             InferenceRunner(    # run inference(for validation) after every epoch
                 dataset_test,   # the DataFlow instance used for validation
                 # Calculate both the cost and the error for this DataFlow
...
@@ -5,7 +5,7 @@
 import tensorflow as tf
 from abc import ABCMeta
 import six
-from ..tfutils.common import get_op_or_tensor_by_name
+from ..tfutils.common import get_op_or_tensor_by_name, get_global_step_value

 __all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback', 'CallbackFactory']
@@ -15,8 +15,10 @@ class Callback(object):
     """ Base class for all callbacks

     Attributes:
-        epoch_num(int): the epoch that have completed the update.
-        local_step(int): the local step number in the current epoch.
+        epoch_num(int): the current epoch num, starting from 1.
+        local_step(int): the current local step number (1-based) in the current epoch,
+            which is also the number of steps that have finished.
+        global_step(int): the number of global steps that have finished.
         trainer(Trainer): the trainer.
         graph(tf.Graph): the graph.
@@ -33,6 +35,7 @@ class Callback(object):
         Args:
             trainer(Trainer): the trainer which calls the callback
         """
+        self._steps_per_epoch = trainer.config.steps_per_epoch
         self.trainer = trainer
         self.graph = tf.get_default_graph()
         with tf.name_scope(type(self).__name__):
@@ -45,6 +48,7 @@ class Callback(object):
         """
         Called right before the first iteration.
         """
+        self._starting_step = get_global_step_value()
         self._before_train()

     def _before_train(self):
@@ -111,7 +115,14 @@ class Callback(object):
     @property
     def local_step(self):
-        return self.trainer.local_step
+        # inside trainer, we're still in the 'local_step' loop, so the number is off by 1
+        return self.trainer.local_step + 1
+
+    @property
+    def global_step(self):
+        return self._starting_step + \
+            self._steps_per_epoch * (self.epoch_num - 1) + \
+            self.local_step

     def __str__(self):
         return type(self).__name__
...
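The new `global_step` property combines three pieces of bookkeeping: the value of the global-step variable when training started, the number of completed epochs, and the 1-based step within the current epoch. A worked example with made-up numbers (not from this commit):

```python
# Fresh training run (global step starts at 0), steps_per_epoch = 100,
# currently in epoch 3 at local step 42:
starting_step = 0
steps_per_epoch = 100
epoch_num, local_step = 3, 42

global_step = starting_step + steps_per_epoch * (epoch_num - 1) + local_step
assert global_step == 242   # 200 steps in epochs 1-2, plus 42 in epoch 3
```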
@@ -56,7 +56,7 @@ class MaintainStepCounter(Callback):
             gs_var, 1,
             name=GLOBAL_STEP_INCR_OP_NAME)
         tf.mod(
-            self.gs_incr_var, self.trainer.config.step_per_epoch,
+            self.gs_incr_var, self.trainer.config.steps_per_epoch,
             name=LOCAL_STEP_OP_NAME)

     def _before_train(self):
@@ -71,12 +71,12 @@ class MaintainStepCounter(Callback):
 class ProgressBar(Callback):
     """ A progress bar based on tqdm. Enabled by default. """
     def _before_train(self):
-        self._total = self.trainer.config.step_per_epoch
+        self._total = self.trainer.config.steps_per_epoch
         self._tqdm_args = get_tqdm_kwargs(leave=True)

     def _trigger_step(self, *args):
-        if self.local_step == 0:
+        if self.local_step == 1:
             self._bar = tqdm.trange(self._total, **self._tqdm_args)
         self._bar.update()
-        if self.local_step == self._total - 1:
+        if self.local_step == self._total:
             self._bar.close()
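Because `local_step` is now 1-based, the bar is created on step 1 and closed on step `total`, rather than on 0 and `total - 1`. A standalone sketch of the same pattern, using plain tqdm outside tensorpack:

```python
import tqdm

total = 5
for local_step in range(1, total + 1):   # 1-based steps, matching the new convention
    if local_step == 1:                  # first step of the epoch: create the bar
        bar = tqdm.trange(total, leave=True)
    bar.update()
    if local_step == total:              # last step of the epoch: close the bar
        bar.close()
```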
@@ -76,5 +76,5 @@ class PeriodicTrigger(ProxyCallback):
     def _trigger_epoch(self, *args):
         if self._epoch_k is None:
             return
-        if self.local_step % self._epoch_k == 0:
+        if self.epoch_num % self._epoch_k == 0:
             self.cb.trigger()
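Testing `local_step` here was a bug: `_trigger_epoch` runs once per epoch, so the period must be counted in epochs. With 1-based `epoch_num`, the wrapped callback fires after epochs `epoch_k`, `2 * epoch_k`, and so on; an illustrative check:

```python
# Which epochs fire for epoch_k = 3 over 10 epochs (illustrative only):
epoch_k = 3
fired = [e for e in range(1, 11) if e % epoch_k == 0]
assert fired == [3, 6, 9]
```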
@@ -163,7 +163,7 @@ class Trainer(object):
                 self.config.starting_epoch, self.config.max_epoch + 1):
             logger.info("Start Epoch {} ...".format(self.epoch_num))
             start_time = time.time()
-            for self.local_step in range(self.config.step_per_epoch):
+            for self.local_step in range(self.config.steps_per_epoch):
                 if self.coord.should_stop():
                     return
                 fetch_data = self.run_step()  # implemented by subclass
...
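Inside the trainer, `self.local_step` still comes from `range(steps_per_epoch)` and is therefore 0-based; the `+ 1` in `Callback.local_step` above converts it to the 1-based view that callbacks see. A sketch of the mapping, assuming nothing beyond the two loops shown:

```python
steps_per_epoch = 3
for trainer_local_step in range(steps_per_epoch):  # 0, 1, 2 inside the trainer loop
    callback_local_step = trainer_local_step + 1   # 1, 2, 3 as Callback.local_step reports
    print(trainer_local_step, callback_local_step)
```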
@@ -28,7 +28,7 @@ class TrainConfig(object):
                  callbacks=None, extra_callbacks=None,
                  session_config=get_default_sess_config(),
                  session_init=None,
-                 starting_epoch=1, step_per_epoch=None, max_epoch=99999,
+                 starting_epoch=1, steps_per_epoch=None, max_epoch=99999,
                  nr_tower=1, tower=None, predict_tower=[0],
                  **kwargs):
         """
@@ -48,7 +48,7 @@ class TrainConfig(object):
             session_config (tf.ConfigProto): the config used to instantiate the session.
             session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
             starting_epoch (int): The index of the first epoch.
-            step_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch.
+            steps_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch.
                 Defaults to the input data size.
             max_epoch (int): maximum number of epoch to run training.
             nr_tower (int): number of training towers.
@@ -103,21 +103,26 @@ class TrainConfig(object):
         self.session_init = session_init
         assert_type(self.session_init, SessionInit)

-        self.step_per_epoch = step_per_epoch
-        if self.step_per_epoch is None:
+        if steps_per_epoch is None:
+            steps_per_epoch = kwargs.pop('step_per_epoch', None)
+            if steps_per_epoch is not None:
+                # TODO deprecate @Mar.27
+                logger.warn("[Deprecated] Use steps_per_epoch instead of step_per_epoch!")
+        if steps_per_epoch is None:
             try:
                 if dataflow is not None:
-                    self.step_per_epoch = self.dataflow.size()
+                    steps_per_epoch = self.dataflow.size()
                 else:
-                    self.step_per_epoch = self.data.size()
+                    steps_per_epoch = self.data.size()
             except NotImplementedError:
-                logger.exception("You must set `step_per_epoch` if dataset.size() is not implemented.")
+                logger.exception("You must set `steps_per_epoch` if dataset.size() is not implemented.")
         else:
-            self.step_per_epoch = int(self.step_per_epoch)
+            steps_per_epoch = int(steps_per_epoch)
+        self.steps_per_epoch = steps_per_epoch

         self.starting_epoch = int(starting_epoch)
         self.max_epoch = int(max_epoch)
-        assert self.step_per_epoch >= 0 and self.max_epoch > 0
+        assert self.steps_per_epoch >= 0 and self.max_epoch > 0

         self.nr_tower = nr_tower
         if tower is not None:
...
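The shim above keeps old call sites working: if `steps_per_epoch` was not given, the old spelling is popped out of `**kwargs` with a deprecation warning before the usual size-based default kicks in. A minimal, self-contained sketch of the same pattern (not the tensorpack code; `make_config` is a made-up name):

```python
import warnings

def make_config(steps_per_epoch=None, **kwargs):
    if steps_per_epoch is None:
        steps_per_epoch = kwargs.pop('step_per_epoch', None)   # accept the old spelling
        if steps_per_epoch is not None:
            warnings.warn("Use steps_per_epoch instead of step_per_epoch!",
                          DeprecationWarning)
    return steps_per_epoch

assert make_config(steps_per_epoch=50) == 50   # new spelling
assert make_config(step_per_epoch=50) == 50    # old spelling still works, with a warning
```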