Commit bc0b7c63 authored by Yuxin Wu's avatar Yuxin Wu

compatibility bug fix

parent a9dd0b8e
......@@ -8,7 +8,7 @@ so you won't need to look at here very often.
Here are a list of things that were changed, starting from an early version.
TensorFlow itself also changes API and those are not listed here.
* 2017/01/27. `TrainConfig(step_per_epoch)` was renamed to `steps_per_epoch`.
* 2017/01/27. `TrainConfig(step_per_epoch)` was renamed to `steps_per_epoch`. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/a9dd0b8ec34209ab86a92875589dbbc4716e73ef).
* 2017/01/25. Argument order of `models.ConcatWith` is changed to follow the API change in
TensorFlow upstream. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/2df3dcf401a99fe61c699ad719e95528872d3abe).
* 2017/01/25. `TrainConfig(callbacks=)` now takes a list of `Callback` instances. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/243e957fe6d62a0cfb5728bd77fb3e005d6603e4)
......
......@@ -60,20 +60,20 @@ class Callbacks(Callback):
assert isinstance(cb, Callback), cb.__class__
# move "StatPrinter" to the last
# TODO don't need to manually move in the future.
found = False
for idx, cb in enumerate(cbs):
if isinstance(cb, StatPrinter):
if found:
raise ValueError("Callbacks cannot contain two StatPrinter!")
sp = cb
cbs.remove(sp)
cbs.append(sp)
if idx != len(cbs) - 1:
logger.warn("StatPrinter should appear as the last element of callbacks! "
"This is now fixed automatically, but may not work in the future.")
found = True
if not found:
break
else:
raise ValueError("Callbacks must contain StatPrinter for stat and writer to work properly!")
nr_printer = sum([int(isinstance(cb, StatPrinter)) for cb in cbs])
if nr_printer != 1:
raise ValueError("Callbacks must contain one StatPrinter!")
self.cbs = cbs
self._extra_fetches_cache = None
......
......@@ -9,7 +9,6 @@ import json
from .base import Callback
from .trigger import Triggerable
from ..utils import logger
from ..tfutils.common import get_global_step_value
__all__ = ['StatHolder', 'StatPrinter', 'SendStat']
......@@ -135,8 +134,9 @@ class StatPrinter(Callback):
def _trigger_epoch(self):
# by default, add this two stat
self._stat_holder.add_stat('global_step', get_global_step_value())
self._stat_holder.add_stat('global_step', self.global_step)
self._stat_holder.finalize()
# this is for the next group of stat
self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment