Commit 89bcdd10 authored by Yuxin Wu

add local_step as callback property. remove some legacy code.

parent 762e4dcc
@@ -17,6 +17,7 @@ class Callback(object):
     Attributes:
         epoch_num(int): the number of the current epoch.
         global_step(int): the number of global steps that have finished.
+        local_step(int): the local steps within the current epoch.
         trainer(Trainer): the trainer.
         graph(tf.Graph): the graph.
@@ -157,6 +158,10 @@ class Callback(object):
     def global_step(self):
         return self.trainer.global_step

+    @property
+    def local_step(self):
+        return self.trainer.local_step
+
     def __str__(self):
         return type(self).__name__
...
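For reference, a minimal standalone sketch (not tensorpack code) of the property-delegation pattern this hunk adds: the callback exposes `local_step` and `global_step` as read-only properties that forward to the trainer, so subclasses can write `self.local_step` instead of reaching into `self.trainer`. `FakeTrainer` and `CallbackSketch` below are illustrative stand-ins.

```python
# Minimal sketch of the delegation pattern; FakeTrainer / CallbackSketch
# are stand-ins, not tensorpack classes.
class FakeTrainer:
    def __init__(self):
        self.global_step = 0   # steps finished since the start of training
        self.local_step = 0    # steps finished within the current epoch


class CallbackSketch:
    def __init__(self, trainer):
        self.trainer = trainer

    @property
    def global_step(self):
        return self.trainer.global_step

    @property
    def local_step(self):
        return self.trainer.local_step


cb = CallbackSketch(FakeTrainer())
print(cb.global_step, cb.local_step)   # -> 0 0
```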
@@ -49,9 +49,8 @@ class StepTensorPrinter(Callback):

 class MaintainStepCounter(Callback):
     """
-    It maintains the global step in the graph and also creates the local step tensor.
-    This callback is always enabled by the trainer, and you wouldn't need to
-    use it.
+    It maintains the global step in the graph, making sure it's increased by one in every `run_step` call.
+    This callback is always enabled by the trainer, and you wouldn't need to use it.
     """
     def _setup_graph(self):
         # ensure it exists
@@ -69,12 +68,12 @@ class MaintainStepCounter(Callback):
         gs_val = get_global_step_value()
         if gs_val != 0:
             logger.info("Start training with global_step={}".format(gs_val))
-        self._last_updated = self.trainer.local_step
+        self._last_updated = self.local_step

     def _before_run(self, _):
         # increase global_step, when trainer.local_step changed
-        if self.trainer.local_step != self._last_updated:
-            self._last_updated = self.trainer.local_step
+        if self.local_step != self._last_updated:
+            self._last_updated = self.local_step
             return self._fetches
         else:
             return None
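The guard above is the heart of this callback: `_before_run` returns the increment fetch only when `local_step` has moved since the last call, so the global step advances exactly once per `run_step` even if `before_run` fires more than once for the same step. A rough plain-Python sketch of that logic (the real code returns TensorFlow fetches rather than mutating a counter):

```python
# Plain-Python illustration of the "only when local_step changed" guard;
# not tensorpack code. The real _before_run returns self._fetches, which
# the session then runs to increment the global_step tensor.
class StepCounterSketch:
    def __init__(self):
        self._last_updated = -1
        self.global_step = 0

    def before_run(self, local_step):
        if local_step != self._last_updated:
            self._last_updated = local_step
            self.global_step += 1           # stands in for running the increment op
            return "increment_global_step"  # stands in for self._fetches
        return None                         # same step seen again: do nothing


ctr = StepCounterSketch()
print(ctr.before_run(0))   # 'increment_global_step'
print(ctr.before_run(0))   # None (local_step unchanged)
print(ctr.before_run(1))   # 'increment_global_step'
```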
@@ -93,7 +92,7 @@ class ProgressBar(Callback):
         self._tags = [get_op_tensor_name(n)[0].split("/")[-1] for n in names]

     def _before_train(self):
-        self._last_updated = self.trainer.local_step
+        self._last_updated = self.local_step
         self._total = self.trainer.config.steps_per_epoch
         self._tqdm_args = get_tqdm_kwargs(leave=True)
@@ -104,10 +103,11 @@ class ProgressBar(Callback):
             self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "

     def _before_run(self, _):
-        if self.trainer.local_step != self._last_updated:
-            self._last_updated = self.trainer.local_step
+        # update progress bar when local step changed (one step is finished)
+        if self.local_step != self._last_updated:
+            self._last_updated = self.local_step

-        if self.trainer.local_step == 0:
+        if self.local_step == 0:
             self._bar = tqdm.trange(self._total, **self._tqdm_args)
         return self._fetches
@@ -121,7 +121,7 @@ class ProgressBar(Callback):

     def _trigger_step(self):
         self._bar.update()
-        if self.trainer.local_step == self._total - 1:
+        if self.local_step == self._total - 1:
             self._bar.close()

     def _after_train(self):
...
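Taken together, the ProgressBar hunks implement a simple lifecycle around `local_step`: create a fresh bar when the epoch starts (`local_step == 0`), advance it once per finished step, and close it at the epoch's last step. A standalone approximation using plain tqdm (the nested loops below stand in for the trainer's epoch/step loop):

```python
# Standalone approximation of the bar lifecycle; not tensorpack code.
import tqdm

steps_per_epoch = 5
for epoch in range(2):
    for local_step in range(steps_per_epoch):
        if local_step == 0:
            bar = tqdm.trange(steps_per_epoch, leave=True)   # new bar per epoch
        # ... run_step would happen here ...
        bar.update()                                         # one tick per finished step
        if local_step == steps_per_epoch - 1:
            bar.close()                                      # close at the last step
```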
@@ -60,7 +60,7 @@ class MergeAllSummaries(Callback):
     def _before_run(self, ctx):
         if self._run_alone:
             return None
-        if self.trainer.local_step == self._total - 1:
+        if self.local_step == self._total - 1:
             return self._fetches
         return None
...
@@ -12,7 +12,7 @@ __all__ = ['PeriodicTrigger', 'PeriodicCallback']

 class PeriodicTrigger(ProxyCallback):
     """
-    Schedule to trigger a callback every k steps or every k epochs by its ``_trigger()`` method.
+    Schedule to trigger a callback every k steps or every k epochs by its ``trigger()`` method.
     """

     def __init__(self, triggerable, every_k_steps=None, every_k_epochs=None):
         """
@@ -37,7 +37,7 @@ class PeriodicTrigger(ProxyCallback):
             return
         # trigger_step is triggered after run_step, so
         # local_step + 1 is the number of step that have finished
-        if (self.trainer.local_step + 1) % self._step_k == 0:
+        if (self.local_step + 1) % self._step_k == 0:
             self.cb.trigger()

     def _trigger_epoch(self):
...
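The modulo check above relies on `_trigger_step` running after `run_step`: when `local_step` is k-1, exactly k steps of the epoch have finished, so `PeriodicTrigger(cb, every_k_steps=100)` (signature as shown in this diff) fires its wrapped `trigger()` after 100, 200, ... finished steps. A plain-Python check of that arithmetic, with example values:

```python
# Plain-Python illustration of the (local_step + 1) % k == 0 condition;
# k and the step range are arbitrary example values.
every_k_steps = 100
fired_after = [local_step + 1                        # steps finished so far
               for local_step in range(300)
               if (local_step + 1) % every_k_steps == 0]
print(fired_after)   # -> [100, 200, 300]
```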
@@ -16,13 +16,13 @@ from .predict import PredictorFactory
 from .config import TrainConfig
 from .monitor import Monitors, TrainingMonitor
 from ..utils import logger
-from ..utils.develop import deprecated, log_deprecated
+from ..utils.develop import deprecated
 from ..callbacks import Callback, Callbacks, MaintainStepCounter
 from ..tfutils import get_global_step_value
 from ..tfutils.model_utils import describe_model
 from ..tfutils.sesscreate import ReuseSessionCreator

-__all__ = ['Trainer', 'StopTraining', 'MultiPredictorTowerTrainer']
+__all__ = ['Trainer', 'StopTraining']


 class StopTraining(BaseException):
@@ -211,13 +211,3 @@ class Trainer(object):
     @deprecated("Use get_predictors instead!", "2017-05-20")
     def get_predict_funcs(self, input_names, output_names, n):
         return self.get_predictors(input_names, output_names, n)
-
-    @deprecated("Don't need to call it any more!", "2017-03-20")
-    def _setup_predictor_factory(self):
-        pass
-
-
-# back-compat
-class MultiPredictorTowerTrainer(Trainer):
-    def __init__(self, *args, **kwargs):
-        log_deprecated("MultiPredictorTowerTrainer", "Just remove it instead.", "2017-03-21")
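The removed legacy shims relied on the `deprecated`/`log_deprecated` helpers from `..utils.develop`. For readers unfamiliar with that helper, here is a minimal sketch of what a decorator of this shape typically does; this is not tensorpack's implementation, and the message wording is assumed.

```python
# Minimal sketch of a deprecated(text, eos) decorator; not tensorpack's
# actual implementation, just the usual warn-then-call-through pattern.
import functools
import warnings


def deprecated(text="", eos=""):
    def wrapper(func):
        @functools.wraps(func)
        def new_func(*args, **kwargs):
            msg = "{} is deprecated".format(func.__name__)
            if text:
                msg += ": " + text
            if eos:
                msg += " Will be removed after {}.".format(eos)
            warnings.warn(msg, DeprecationWarning)
            return func(*args, **kwargs)
        return new_func
    return wrapper


@deprecated("Use get_predictors instead!", "2017-05-20")
def get_predict_funcs():
    pass
```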
@@ -162,17 +162,6 @@ class TrainConfig(object):
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

-    def set_tower(self, nr_tower=None, tower=None):
-        log_deprecated("config.set_tower", "Set config.tower or config.nr_tower directly.", "2017-03-15")
-        assert nr_tower is None or tower is None, "Cannot set both nr_tower and tower!"
-        if nr_tower:
-            tower = list(range(nr_tower))
-        else:
-            if isinstance(tower, int):
-                tower = list(range(tower))
-        self.tower = tower
-        assert isinstance(self.tower, list)
-
     @property
     def nr_tower(self):
         return len(self.tower)
...
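With `set_tower` gone, towers are configured by assigning the attribute directly, while `nr_tower` stays a derived property (the hunk above shows it returning `len(self.tower)`). A standalone mimic of that relation; `ConfigSketch` is illustrative only, not `TrainConfig`:

```python
# Standalone mimic of the tower / nr_tower relation kept after the removal;
# ConfigSketch is a stand-in, not tensorpack's TrainConfig.
class ConfigSketch:
    def __init__(self):
        self.tower = [0]          # tower indices (e.g. GPU ids) to train on

    @property
    def nr_tower(self):
        return len(self.tower)    # derived, as in the diff above


cfg = ConfigSketch()
cfg.tower = list(range(4))        # replaces the removed set_tower(nr_tower=4)
print(cfg.nr_tower)               # -> 4
```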