Commit eb57892e authored by Yuxin Wu's avatar Yuxin Wu

update docs and clean up some legacy

parent efc8a5f0
......@@ -44,8 +44,20 @@ TrainConfig(
MovingAverageSummary(),
# draw a nice progress bar
ProgressBar(),
# run `tf.summary.merge_all` and save results every epoch
# run `tf.summary.merge_all` every epoch and send results to monitors
MergeAllSummaries(),
]
],
monitors=[ # monitors are a special kind of callback; these are also enabled by default
# write all monitor data to tensorboard
TFSummaryWriter(),
# write all scalar data to a json file, for easy parsing
JSONWriter(),
# print all scalar data every epoch (can be configured differently)
ScalarPrinter(),
]
)
```
Notice that callbacks really cover every detail of training, ranging from graph operations to the progress bar.
This means you can customize every part of the training to your preference, e.g. display something
different in the progress bar, evaluate part of the summaries at a different frequency, etc.
......@@ -4,11 +4,9 @@
import numpy as np
from abc import ABCMeta, abstractmethod
import sys
import six
from six.moves import zip
from ..utils import logger
from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils import get_op_tensor_name
......@@ -140,14 +138,10 @@ class ClassificationError(Inferencer):
def _datapoint(self, outputs):
    """Accumulate classification-error statistics from one inferred datapoint.

    Args:
        outputs: list whose first element is the 'wrong vector' -- expected to
            be a 1-D array with one entry per sample in the batch (nonzero
            means the sample was misclassified).
    """
    vec = outputs[0]
    # The legacy scalar `nr_wrong` interface was removed in this commit;
    # a per-sample 1-D vector is now required.
    # TODO put shape assertion into inference-runner
    assert vec.ndim == 1, "{} is not a vector!".format(self.wrong_tensor_name)
    batch_size = len(vec)
    wrong = np.sum(vec)
    # Feed (#wrong, #total) for this batch into the running ratio counter.
    self.err_stat.feed(wrong, batch_size)
def _after_inference(self):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: steps.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
""" Some common step callbacks. """
......
......@@ -4,10 +4,8 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from .base import ProxyCallback, Callback
from ..utils.develop import log_deprecated
__all__ = ['PeriodicTrigger', 'PeriodicCallback']
__all__ = ['PeriodicTrigger']
class PeriodicTrigger(ProxyCallback):
......@@ -48,30 +46,3 @@ class PeriodicTrigger(ProxyCallback):
def __str__(self):
    """Readable name showing which callback this trigger wraps."""
    return "PeriodicTrigger-{}".format(self.cb)
class PeriodicCallback(ProxyCallback):
    """
    Wrap a callback so that after every ``period`` epochs, its
    :meth:`trigger_epoch` method is called.

    This wrapper is legacy. It will only proxy the :meth:`trigger_step`
    method as-is. To be able to schedule a callback more frequent than once
    per epoch, use :class:`PeriodicTrigger` instead.
    """

    def __init__(self, cb, period):
        """
        Args:
            cb(Callback): the callback to be triggered periodically
            period(int): the period, the number of epochs for a callback to be triggered.
        """
        super(PeriodicCallback, self).__init__(cb)
        self.period = int(period)
        log_deprecated("PeriodicCallback", "Use the more powerful `PeriodicTrigger`.")

    def _trigger_epoch(self):
        # Guard clause: only fire on every `period`-th epoch.
        if self.epoch_num % self.period != 0:
            return
        self.cb.trigger_epoch()

    def __str__(self):
        return "Periodic-{}".format(self.cb)
......@@ -9,7 +9,6 @@ from six.moves import queue, range
import tensorflow as tf
from ..utils import logger
from ..utils.develop import deprecated
from ..utils.concurrency import DIE, StoppableThread, ShareSessionThread
from ..tfutils.model_utils import describe_model
from .base import OnlinePredictor, OfflinePredictor, AsyncPredictorBase
......@@ -165,10 +164,6 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
for t in self.threads:
t.start()
@deprecated("Use 'start()' instead!", "2017-03-11")
def run(self):
    """Deprecated alias kept temporarily for back-compatibility; delegates to :meth:`start`."""
    self.start()
def put_task(self, dp, callback=None):
"""
Same as in :meth:`AsyncPredictorBase.put_task`.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment