Commit 827791cb authored by Yuxin Wu's avatar Yuxin Wu

fix import #96

parent 563b9cd6
...@@ -30,6 +30,7 @@ class StatHolder(object): ...@@ -30,6 +30,7 @@ class StatHolder(object):
self.log_dir = log_dir self.log_dir = log_dir
self.filename = os.path.join(log_dir, 'stat.json') self.filename = os.path.join(log_dir, 'stat.json')
if os.path.isfile(self.filename): if os.path.isfile(self.filename):
# TODO make a backup first?
logger.info("Found stats at {}, will append to it.".format(self.filename)) logger.info("Found stats at {}, will append to it.".format(self.filename))
with open(self.filename) as f: with open(self.filename) as f:
self.stat_history = json.load(f) self.stat_history = json.load(f)
...@@ -62,13 +63,16 @@ class StatHolder(object): ...@@ -62,13 +63,16 @@ class StatHolder(object):
def get_stat_now(self, key): def get_stat_now(self, key):
""" """
Return the value of a stat in the current epoch. Return the value of a stat in the current epoch.
Raises:
KeyError if the key hasn't been added in this epoch.
""" """
return self.stat_now[key] return self.stat_now[key]
def get_stat_history(self, key): def get_stat_history(self, key):
""" """
Returns: Returns:
list: all history of a stat. list: all history of a stat. Empty if there is not history of this name.
""" """
ret = [] ret = []
for h in self.stat_history: for h in self.stat_history:
...@@ -82,7 +86,8 @@ class StatHolder(object): ...@@ -82,7 +86,8 @@ class StatHolder(object):
def finalize(self): def finalize(self):
""" """
Called after finishing adding stats for this epoch. Will print and write stats to disk. Called after finishing adding stats for this epoch.
Will print and write stats to disk.
""" """
self._print_stat() self._print_stat()
self.stat_history.append(self.stat_now) self.stat_history.append(self.stat_now)
......
...@@ -16,7 +16,7 @@ from ..utils import logger, get_tqdm ...@@ -16,7 +16,7 @@ from ..utils import logger, get_tqdm
from ..utils.gpu import change_gpu from ..utils.gpu import change_gpu
from .concurrency import MultiProcessQueuePredictWorker from .concurrency import MultiProcessQueuePredictWorker
from .common import PredictConfig from .config import PredictConfig
from .base import OfflinePredictor from .base import OfflinePredictor
__all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor', __all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment