Commit 1cd4a380 authored by Yuxin Wu

Loaded epoch number doesn't take effect before training has started.

parent ab8503e8
@@ -251,6 +251,7 @@ if __name__ == '__main__':
     args = parser.parse_args()
     DEPTH = args.depth
-    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
+    if args.gpu:
+        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
     if args.eval:
@@ -258,12 +259,12 @@ if __name__ == '__main__':
         eval_on_ILSVRC12(args.load, args.data)
         sys.exit()
-    NR_GPU = len(args.gpu.split(','))
+    NR_GPU = get_nr_gpu()
     BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU
     logger.set_logger_dir(
         os.path.join('train_log', 'imagenet-resnet-d' + str(DEPTH)))
-    logger.info("Batch size per GPU: " + str(BATCH_SIZE))
+    logger.info("Running on {} GPUs. Batch size per GPU: {}".format(NR_GPU, BATCH_SIZE))
     config = get_config(fake=args.fake, data_format=args.data_format)
     if args.load:
         config.session_init = SaverRestore(args.load)
......
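The first change guards the CUDA_VISIBLE_DEVICES assignment behind `if args.gpu:` and derives the GPU count from `get_nr_gpu()` instead of parsing the `--gpu` flag, so the per-GPU batch size is computed from whatever devices are actually visible. A minimal sketch of that interplay, outside tensorpack: `visible_gpu_count` below is a hypothetical stand-in for `get_nr_gpu`, and the `--gpu` value and total batch size are illustrative.

import os

def visible_gpu_count():
    # Hypothetical stand-in for get_nr_gpu(): count devices listed in
    # CUDA_VISIBLE_DEVICES when it is set, otherwise assume one GPU.
    env = os.environ.get('CUDA_VISIBLE_DEVICES')
    if env:
        return len([d for d in env.split(',') if d.strip()])
    return 1

# Only restrict visibility when --gpu was given, mirroring the new `if args.gpu:` guard.
gpu_arg = '0,1'                      # e.g. the value passed to --gpu; illustrative
if gpu_arg:
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_arg

TOTAL_BATCH_SIZE = 256               # illustrative total batch size
NR_GPU = visible_gpu_count()
BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU   # per-GPU batch size, as in the diff
print("Running on {} GPUs. Batch size per GPU: {}".format(NR_GPU, BATCH_SIZE))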
@@ -58,11 +58,19 @@ class Trainer(object):
         self.config = config
         self.model = config.model
-        self.epoch_num = self.config.starting_epoch - 1

         self.local_step = -1
         self._callbacks = []
         self.monitors = []
+        self._epoch_num = None
+
+    @property
+    def epoch_num(self):
+        if self._epoch_num is not None:
+            # has started training
+            return self._epoch_num
+        else:
+            return self.config.starting_epoch - 1

     def register_callback(self, cb):
         """
@@ -170,9 +178,9 @@ class Trainer(object):
         self._callbacks.before_train()
         # refresh global step (might have changed by callbacks) TODO ugly
         self._starting_step = get_global_step_value()
-        for self.epoch_num in range(
+        for self._epoch_num in range(
                 self.config.starting_epoch, self.config.max_epoch + 1):
-            logger.info("Start Epoch {} ...".format(self.epoch_num))
+            logger.info("Start Epoch {} ...".format(self._epoch_num))
             start_time = time.time()
             self._callbacks.before_epoch()
             for self.local_step in range(self.config.steps_per_epoch):
@@ -182,7 +190,7 @@ class Trainer(object):
                 self._callbacks.trigger_step()
             self._callbacks.after_epoch()
             logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
-                self.epoch_num, self.global_step, time.time() - start_time))
+                self._epoch_num, self.global_step, time.time() - start_time))
             # trigger epoch outside the timing region.
             self._trigger_epoch()
......
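The Trainer change stores the running epoch in a private `_epoch_num` field and exposes `epoch_num` as a read-only property that falls back to `starting_epoch - 1` until the first epoch begins, which is what makes a loaded epoch number visible before training starts. A self-contained sketch of the pattern: `MiniTrainer` and its constructor arguments are illustrative, not tensorpack's API.

class MiniTrainer:
    def __init__(self, starting_epoch, max_epoch):
        self.starting_epoch = starting_epoch
        self.max_epoch = max_epoch
        self._epoch_num = None                 # training has not started yet

    @property
    def epoch_num(self):
        if self._epoch_num is not None:
            return self._epoch_num             # inside or after the training loop
        return self.starting_epoch - 1         # before training: reflect the loaded epoch

    def train(self):
        for self._epoch_num in range(self.starting_epoch, self.max_epoch + 1):
            pass                               # one epoch of work would go here

t = MiniTrainer(starting_epoch=5, max_epoch=6)
assert t.epoch_num == 4    # meaningful even before train() is called
t.train()
assert t.epoch_num == 6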