Commit 9ce85aae authored by Yuxin Wu's avatar Yuxin Wu

Fix the loading of json stats (fixes #904, a regression introduced by 4a917d6e)

parent 36b05bb7
......@@ -78,6 +78,9 @@ class ScalarStats(Inferencer):
"""
Statistics of some scalar tensor.
The value will be averaged over all given datapoints.
Note that the average of accuracy over all batches is not necessarily the
accuracy of the whole dataset. See :class:`ClassificationError` for details.
"""
def __init__(self, names, prefix='validation'):
......
......@@ -306,7 +306,13 @@ class JSONWriter(TrainingMonitor):
except Exception:
return None
# initialize the stats here, because before_train from other callbacks may use it
def _setup_graph(self):
self._stats = []
self._stat_now = {}
self._last_gs = -1
def _before_train(self):
stats = JSONWriter.load_existing_json()
self._fname = os.path.join(logger.get_logger_dir(), JSONWriter.FILENAME)
if stats is not None:
......@@ -315,33 +321,27 @@ class JSONWriter(TrainingMonitor):
except Exception:
epoch = None
# check against the current training settings
# therefore this logic needs to be in before_train stage
starting_epoch = self.trainer.loop.starting_epoch
if epoch is None or epoch == starting_epoch:
logger.info("Found existing JSON inside {}, will append to it.".format(logger.get_logger_dir()))
self._stats = stats
else:
logger.warn(
"History epoch value {} from JSON is not the predecessor of the starting_epoch value {}".format(
"History epoch={} from JSON is not the predecessor of the current starting_epoch={}".format(
epoch - 1, starting_epoch))
logger.warn("If you want to resume old training, either use `AutoResumeTrainConfig` "
"or correctly set the starting_epoch yourself to avoid inconsistency. "
"Epoch number will not be automatically loaded by JSONWriter.")
"or correctly set the new starting_epoch yourself to avoid inconsistency. ")
backup_fname = JSONWriter.FILENAME + '.' + datetime.now().strftime('%m%d-%H%M%S')
backup_fname = os.path.join(logger.get_logger_dir(), backup_fname)
logger.warn("Now, we will start training at epoch {} and backup old json to {}".format(
logger.warn("Now, we will train with starting_epoch={} and backup old json to {}".format(
self.trainer.loop.starting_epoch, backup_fname))
shutil.move(self._fname, backup_fname)
self._stats = []
else:
self._stats = []
self._stat_now = {}
self._last_gs = -1
# in case we have something to log here.
def _before_train(self):
# in case we have something to log here.
self._trigger()
def _trigger_step(self):
......
......@@ -85,7 +85,7 @@ def _parse_meta(filename, cifar_classnum):
class CifarBase(RNGDataFlow):
def __init__(self, train_or_test, shuffle=True, dir=None, cifar_classnum=10):
def __init__(self, train_or_test, shuffle=None, dir=None, cifar_classnum=10):
assert train_or_test in ['train', 'test']
assert cifar_classnum == 10 or cifar_classnum == 100
self.cifar_classnum = cifar_classnum
......@@ -104,6 +104,9 @@ class CifarBase(RNGDataFlow):
self.train_or_test = train_or_test
self.data = read_cifar(self.fs, cifar_classnum)
self.dir = dir
if shuffle is None:
shuffle = train_or_test == 'train'
self.shuffle = shuffle
def __len__(self):
......@@ -149,18 +152,18 @@ class Cifar10(CifarBase):
image is 32x32x3 in the range [0,255].
label is an int.
"""
def __init__(self, train_or_test, shuffle=True, dir=None):
def __init__(self, train_or_test, shuffle=None, dir=None):
"""
Args:
train_or_test (str): either 'train' or 'test'.
shuffle (bool): shuffle the dataset.
shuffle (bool): shuffle the dataset, default to shuffle in training
"""
super(Cifar10, self).__init__(train_or_test, shuffle, dir, 10)
class Cifar100(CifarBase):
""" Similar to Cifar10"""
def __init__(self, train_or_test, shuffle=True, dir=None):
def __init__(self, train_or_test, shuffle=None, dir=None):
super(Cifar100, self).__init__(train_or_test, shuffle, dir, 100)
......
......@@ -222,7 +222,7 @@ class AutoResumeTrainConfig(TrainConfig):
if last_epoch is not None:
now_epoch = last_epoch + 1
logger.info("Found history statistics from JSON. "
"Overwrite the starting epoch to epoch #{}.".format(now_epoch))
"Setting starting_epoch to {}.".format(now_epoch))
kwargs['starting_epoch'] = now_epoch
super(AutoResumeTrainConfig, self).__init__(**kwargs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment