You need to sign in or sign up before continuing.
Commit 1cd4a380 authored by Yuxin Wu's avatar Yuxin Wu

Loaded epoch number doesn't take effect before training has started.

parent ab8503e8
...@@ -251,19 +251,20 @@ if __name__ == '__main__': ...@@ -251,19 +251,20 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
DEPTH = args.depth DEPTH = args.depth
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.eval: if args.eval:
BATCH_SIZE = 128 # something that can run on one gpu BATCH_SIZE = 128 # something that can run on one gpu
eval_on_ILSVRC12(args.load, args.data) eval_on_ILSVRC12(args.load, args.data)
sys.exit() sys.exit()
NR_GPU = len(args.gpu.split(',')) NR_GPU = get_nr_gpu()
BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU
logger.set_logger_dir( logger.set_logger_dir(
os.path.join('train_log', 'imagenet-resnet-d' + str(DEPTH))) os.path.join('train_log', 'imagenet-resnet-d' + str(DEPTH)))
logger.info("Batch size per GPU: " + str(BATCH_SIZE)) logger.info("Running on {} GPUs. Batch size per GPU: {}".format(NR_GPU, BATCH_SIZE))
config = get_config(fake=args.fake, data_format=args.data_format) config = get_config(fake=args.fake, data_format=args.data_format)
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
......
...@@ -58,11 +58,19 @@ class Trainer(object): ...@@ -58,11 +58,19 @@ class Trainer(object):
self.config = config self.config = config
self.model = config.model self.model = config.model
self.epoch_num = self.config.starting_epoch - 1
self.local_step = -1 self.local_step = -1
self._callbacks = [] self._callbacks = []
self.monitors = [] self.monitors = []
self._epoch_num = None
@property
def epoch_num(self):
if self._epoch_num is not None:
# has started training
return self._epoch_num
else:
return self.config.starting_epoch - 1
def register_callback(self, cb): def register_callback(self, cb):
""" """
...@@ -170,9 +178,9 @@ class Trainer(object): ...@@ -170,9 +178,9 @@ class Trainer(object):
self._callbacks.before_train() self._callbacks.before_train()
# refresh global step (might have changed by callbacks) TODO ugly # refresh global step (might have changed by callbacks) TODO ugly
self._starting_step = get_global_step_value() self._starting_step = get_global_step_value()
for self.epoch_num in range( for self._epoch_num in range(
self.config.starting_epoch, self.config.max_epoch + 1): self.config.starting_epoch, self.config.max_epoch + 1):
logger.info("Start Epoch {} ...".format(self.epoch_num)) logger.info("Start Epoch {} ...".format(self._epoch_num))
start_time = time.time() start_time = time.time()
self._callbacks.before_epoch() self._callbacks.before_epoch()
for self.local_step in range(self.config.steps_per_epoch): for self.local_step in range(self.config.steps_per_epoch):
...@@ -182,7 +190,7 @@ class Trainer(object): ...@@ -182,7 +190,7 @@ class Trainer(object):
self._callbacks.trigger_step() self._callbacks.trigger_step()
self._callbacks.after_epoch() self._callbacks.after_epoch()
logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format( logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
self.epoch_num, self.global_step, time.time() - start_time)) self._epoch_num, self.global_step, time.time() - start_time))
# trigger epoch outside the timing region. # trigger epoch outside the timing region.
self._trigger_epoch() self._trigger_epoch()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment