Commit 89bcdd10 authored by Yuxin Wu's avatar Yuxin Wu

Add local_step as a callback property; remove some legacy code.

parent 762e4dcc
......@@ -17,6 +17,7 @@ class Callback(object):
Attributes:
epoch_num(int): the number of the current epoch.
global_step(int): the number of global steps that have finished.
local_step(int): the number of local steps within the current epoch.
trainer(Trainer): the trainer.
graph(tf.Graph): the graph.
......@@ -157,6 +158,10 @@ class Callback(object):
def global_step(self):
return self.trainer.global_step
@property
def local_step(self):
return self.trainer.local_step
def __str__(self):
return type(self).__name__
......
......@@ -49,9 +49,8 @@ class StepTensorPrinter(Callback):
class MaintainStepCounter(Callback):
"""
It maintains the global step in the graph and also creates the local step tensor.
This callback is always enabled by the trainer, and you wouldn't need to
use it.
It maintains the global step in the graph, making sure it's increased by one in every `run_step` call.
This callback is always enabled by the trainer, and you wouldn't need to use it.
"""
def _setup_graph(self):
# ensure it exists
......@@ -69,12 +68,12 @@ class MaintainStepCounter(Callback):
gs_val = get_global_step_value()
if gs_val != 0:
logger.info("Start training with global_step={}".format(gs_val))
self._last_updated = self.trainer.local_step
self._last_updated = self.local_step
def _before_run(self, _):
# increase global_step, when trainer.local_step changed
if self.trainer.local_step != self._last_updated:
self._last_updated = self.trainer.local_step
if self.local_step != self._last_updated:
self._last_updated = self.local_step
return self._fetches
else:
return None
......@@ -93,7 +92,7 @@ class ProgressBar(Callback):
self._tags = [get_op_tensor_name(n)[0].split("/")[-1] for n in names]
def _before_train(self):
self._last_updated = self.trainer.local_step
self._last_updated = self.local_step
self._total = self.trainer.config.steps_per_epoch
self._tqdm_args = get_tqdm_kwargs(leave=True)
......@@ -104,10 +103,11 @@ class ProgressBar(Callback):
self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "
def _before_run(self, _):
if self.trainer.local_step != self._last_updated:
self._last_updated = self.trainer.local_step
# update progress bar when local step changed (one step is finished)
if self.local_step != self._last_updated:
self._last_updated = self.local_step
if self.trainer.local_step == 0:
if self.local_step == 0:
self._bar = tqdm.trange(self._total, **self._tqdm_args)
return self._fetches
......@@ -121,7 +121,7 @@ class ProgressBar(Callback):
def _trigger_step(self):
self._bar.update()
if self.trainer.local_step == self._total - 1:
if self.local_step == self._total - 1:
self._bar.close()
def _after_train(self):
......
......@@ -60,7 +60,7 @@ class MergeAllSummaries(Callback):
def _before_run(self, ctx):
if self._run_alone:
return None
if self.trainer.local_step == self._total - 1:
if self.local_step == self._total - 1:
return self._fetches
return None
......
......@@ -12,7 +12,7 @@ __all__ = ['PeriodicTrigger', 'PeriodicCallback']
class PeriodicTrigger(ProxyCallback):
"""
Schedule to trigger a callback every k steps or every k epochs by its ``_trigger()`` method.
Schedule to trigger a callback every k steps or every k epochs by its ``trigger()`` method.
"""
def __init__(self, triggerable, every_k_steps=None, every_k_epochs=None):
"""
......@@ -37,7 +37,7 @@ class PeriodicTrigger(ProxyCallback):
return
# trigger_step is triggered after run_step, so
# local_step + 1 is the number of step that have finished
if (self.trainer.local_step + 1) % self._step_k == 0:
if (self.local_step + 1) % self._step_k == 0:
self.cb.trigger()
def _trigger_epoch(self):
......
......@@ -16,13 +16,13 @@ from .predict import PredictorFactory
from .config import TrainConfig
from .monitor import Monitors, TrainingMonitor
from ..utils import logger
from ..utils.develop import deprecated, log_deprecated
from ..utils.develop import deprecated
from ..callbacks import Callback, Callbacks, MaintainStepCounter
from ..tfutils import get_global_step_value
from ..tfutils.model_utils import describe_model
from ..tfutils.sesscreate import ReuseSessionCreator
__all__ = ['Trainer', 'StopTraining', 'MultiPredictorTowerTrainer']
__all__ = ['Trainer', 'StopTraining']
class StopTraining(BaseException):
......@@ -211,13 +211,3 @@ class Trainer(object):
@deprecated("Use get_predictors instead!", "2017-05-20")
def get_predict_funcs(self, input_names, output_names, n):
return self.get_predictors(input_names, output_names, n)
@deprecated("Don't need to call it any more!", "2017-03-20")
def _setup_predictor_factory(self):
pass
# back-compat
class MultiPredictorTowerTrainer(Trainer):
def __init__(self, *args, **kwargs):
log_deprecated("MultiPredictorTowerTrainer", "Just remove it instead.", "2017-03-21")
......@@ -162,17 +162,6 @@ class TrainConfig(object):
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
def set_tower(self, nr_tower=None, tower=None):
log_deprecated("config.set_tower", "Set config.tower or config.nr_tower directly.", "2017-03-15")
assert nr_tower is None or tower is None, "Cannot set both nr_tower and tower!"
if nr_tower:
tower = list(range(nr_tower))
else:
if isinstance(tower, int):
tower = list(range(tower))
self.tower = tower
assert isinstance(self.tower, list)
@property
def nr_tower(self):
return len(self.tower)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment