Commit 9ce85aae authored by Yuxin Wu

Fix the loading of json stats (fix #904 introduced by 4a917d6e)

parent 36b05bb7
...@@ -78,6 +78,9 @@ class ScalarStats(Inferencer): ...@@ -78,6 +78,9 @@ class ScalarStats(Inferencer):
""" """
Statistics of some scalar tensor. Statistics of some scalar tensor.
The value will be averaged over all given datapoints. The value will be averaged over all given datapoints.
Note that the average of accuracy over all batches is not necessarily the
accuracy of the whole dataset. See :class:`ClassificationError` for details.
""" """
def __init__(self, names, prefix='validation'): def __init__(self, names, prefix='validation'):
......
...@@ -306,7 +306,13 @@ class JSONWriter(TrainingMonitor): ...@@ -306,7 +306,13 @@ class JSONWriter(TrainingMonitor):
except Exception: except Exception:
return None return None
# initialize the stats here, because before_train from other callbacks may use it
def _setup_graph(self): def _setup_graph(self):
self._stats = []
self._stat_now = {}
self._last_gs = -1
def _before_train(self):
stats = JSONWriter.load_existing_json() stats = JSONWriter.load_existing_json()
self._fname = os.path.join(logger.get_logger_dir(), JSONWriter.FILENAME) self._fname = os.path.join(logger.get_logger_dir(), JSONWriter.FILENAME)
if stats is not None: if stats is not None:
...@@ -315,33 +321,27 @@ class JSONWriter(TrainingMonitor): ...@@ -315,33 +321,27 @@ class JSONWriter(TrainingMonitor):
except Exception: except Exception:
epoch = None epoch = None
# check against the current training settings
# therefore this logic needs to be in before_train stage
starting_epoch = self.trainer.loop.starting_epoch starting_epoch = self.trainer.loop.starting_epoch
if epoch is None or epoch == starting_epoch: if epoch is None or epoch == starting_epoch:
logger.info("Found existing JSON inside {}, will append to it.".format(logger.get_logger_dir())) logger.info("Found existing JSON inside {}, will append to it.".format(logger.get_logger_dir()))
self._stats = stats self._stats = stats
else: else:
logger.warn( logger.warn(
"History epoch value {} from JSON is not the predecessor of the starting_epoch value {}".format( "History epoch={} from JSON is not the predecessor of the current starting_epoch={}".format(
epoch - 1, starting_epoch)) epoch - 1, starting_epoch))
logger.warn("If you want to resume old training, either use `AutoResumeTrainConfig` " logger.warn("If you want to resume old training, either use `AutoResumeTrainConfig` "
"or correctly set the starting_epoch yourself to avoid inconsistency. " "or correctly set the new starting_epoch yourself to avoid inconsistency. ")
"Epoch number will not be automatically loaded by JSONWriter.")
backup_fname = JSONWriter.FILENAME + '.' + datetime.now().strftime('%m%d-%H%M%S') backup_fname = JSONWriter.FILENAME + '.' + datetime.now().strftime('%m%d-%H%M%S')
backup_fname = os.path.join(logger.get_logger_dir(), backup_fname) backup_fname = os.path.join(logger.get_logger_dir(), backup_fname)
logger.warn("Now, we will start training at epoch {} and backup old json to {}".format( logger.warn("Now, we will train with starting_epoch={} and backup old json to {}".format(
self.trainer.loop.starting_epoch, backup_fname)) self.trainer.loop.starting_epoch, backup_fname))
shutil.move(self._fname, backup_fname) shutil.move(self._fname, backup_fname)
self._stats = []
else:
self._stats = []
self._stat_now = {}
self._last_gs = -1
# in case we have something to log here. # in case we have something to log here.
def _before_train(self):
self._trigger() self._trigger()
def _trigger_step(self): def _trigger_step(self):
......
...@@ -85,7 +85,7 @@ def _parse_meta(filename, cifar_classnum): ...@@ -85,7 +85,7 @@ def _parse_meta(filename, cifar_classnum):
class CifarBase(RNGDataFlow): class CifarBase(RNGDataFlow):
def __init__(self, train_or_test, shuffle=True, dir=None, cifar_classnum=10): def __init__(self, train_or_test, shuffle=None, dir=None, cifar_classnum=10):
assert train_or_test in ['train', 'test'] assert train_or_test in ['train', 'test']
assert cifar_classnum == 10 or cifar_classnum == 100 assert cifar_classnum == 10 or cifar_classnum == 100
self.cifar_classnum = cifar_classnum self.cifar_classnum = cifar_classnum
...@@ -104,6 +104,9 @@ class CifarBase(RNGDataFlow): ...@@ -104,6 +104,9 @@ class CifarBase(RNGDataFlow):
self.train_or_test = train_or_test self.train_or_test = train_or_test
self.data = read_cifar(self.fs, cifar_classnum) self.data = read_cifar(self.fs, cifar_classnum)
self.dir = dir self.dir = dir
if shuffle is None:
shuffle = train_or_test == 'train'
self.shuffle = shuffle self.shuffle = shuffle
def __len__(self): def __len__(self):
...@@ -149,18 +152,18 @@ class Cifar10(CifarBase): ...@@ -149,18 +152,18 @@ class Cifar10(CifarBase):
image is 32x32x3 in the range [0,255]. image is 32x32x3 in the range [0,255].
label is an int. label is an int.
""" """
def __init__(self, train_or_test, shuffle=True, dir=None): def __init__(self, train_or_test, shuffle=None, dir=None):
""" """
Args: Args:
train_or_test (str): either 'train' or 'test'. train_or_test (str): either 'train' or 'test'.
shuffle (bool): shuffle the dataset. shuffle (bool): shuffle the dataset, default to shuffle in training
""" """
super(Cifar10, self).__init__(train_or_test, shuffle, dir, 10) super(Cifar10, self).__init__(train_or_test, shuffle, dir, 10)
class Cifar100(CifarBase): class Cifar100(CifarBase):
""" Similar to Cifar10""" """ Similar to Cifar10"""
def __init__(self, train_or_test, shuffle=True, dir=None): def __init__(self, train_or_test, shuffle=None, dir=None):
super(Cifar100, self).__init__(train_or_test, shuffle, dir, 100) super(Cifar100, self).__init__(train_or_test, shuffle, dir, 100)
......
...@@ -222,7 +222,7 @@ class AutoResumeTrainConfig(TrainConfig): ...@@ -222,7 +222,7 @@ class AutoResumeTrainConfig(TrainConfig):
if last_epoch is not None: if last_epoch is not None:
now_epoch = last_epoch + 1 now_epoch = last_epoch + 1
logger.info("Found history statistics from JSON. " logger.info("Found history statistics from JSON. "
"Overwrite the starting epoch to epoch #{}.".format(now_epoch)) "Setting starting_epoch to {}.".format(now_epoch))
kwargs['starting_epoch'] = now_epoch kwargs['starting_epoch'] = now_epoch
super(AutoResumeTrainConfig, self).__init__(**kwargs) super(AutoResumeTrainConfig, self).__init__(**kwargs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment