Commit 84ba85fd authored by Yuxin Wu's avatar Yuxin Wu

starting epoch

parent bf0ef32e
...@@ -29,7 +29,7 @@ class Callback(object): ...@@ -29,7 +29,7 @@ class Callback(object):
def before_train(self, trainer): def before_train(self, trainer):
self.trainer = trainer self.trainer = trainer
self.graph = tf.get_default_graph() self.graph = tf.get_default_graph()
self.epoch_num = 0 self.epoch_num = self.trainer.config.starting_epoch
self._before_train() self._before_train()
def _before_train(self): def _before_train(self):
...@@ -59,8 +59,8 @@ class Callback(object): ...@@ -59,8 +59,8 @@ class Callback(object):
""" """
epoch_num is the number of epoch finished. epoch_num is the number of epoch finished.
""" """
self.epoch_num += 1
self._trigger_epoch() self._trigger_epoch()
self.epoch_num += 1
def _trigger_epoch(self): def _trigger_epoch(self):
""" """
......
...@@ -77,7 +77,7 @@ class Trainer(object): ...@@ -77,7 +77,7 @@ class Trainer(object):
self.global_step = get_global_step() self.global_step = get_global_step()
logger.info("Start training with global_step={}".format(self.global_step)) logger.info("Start training with global_step={}".format(self.global_step))
for epoch in range(1, self.config.max_epoch): for epoch in range(self.config.starting_epoch, self.config.max_epoch):
with timed_operation( with timed_operation(
'Epoch {}, global_step={}'.format( 'Epoch {}, global_step={}'.format(
epoch, self.global_step + self.config.step_per_epoch)): epoch, self.global_step + self.config.step_per_epoch)):
......
...@@ -29,8 +29,8 @@ class TrainConfig(object): ...@@ -29,8 +29,8 @@ class TrainConfig(object):
session_init: a tensorpack.utils.sessinit.SessionInit instance to session_init: a tensorpack.utils.sessinit.SessionInit instance to
initialize variables of a session. default to a new session. initialize variables of a session. default to a new session.
model: a ModelDesc instance model: a ModelDesc instance
step_per_epoch: the number of steps (parameter updates) to perform starting_epoch: int. default to be 1.
in each epoch. step_per_epoch: the number of steps (SGD updates) to perform in each epoch.
max_epoch: maximum number of epoch to run training. default to 100 max_epoch: maximum number of epoch to run training. default to 100
nr_tower: int. number of towers. default to 1. nr_tower: int. number of towers. default to 1.
""" """
...@@ -50,6 +50,7 @@ class TrainConfig(object): ...@@ -50,6 +50,7 @@ class TrainConfig(object):
self.session_init = kwargs.pop('session_init', NewSession()) self.session_init = kwargs.pop('session_init', NewSession())
assert_type(self.session_init, SessionInit) assert_type(self.session_init, SessionInit)
self.step_per_epoch = int(kwargs.pop('step_per_epoch')) self.step_per_epoch = int(kwargs.pop('step_per_epoch'))
self.starting_epoch = int(kwargs.pop('starting_epoch', 1))
self.max_epoch = int(kwargs.pop('max_epoch', 100)) self.max_epoch = int(kwargs.pop('max_epoch', 100))
assert self.step_per_epoch > 0 and self.max_epoch > 0 assert self.step_per_epoch > 0 and self.max_epoch > 0
self.nr_tower = int(kwargs.pop('nr_tower', 1)) self.nr_tower = int(kwargs.pop('nr_tower', 1))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment