Commit 4aaf06ca authored by Yuxin Wu

separate setup_graph & before_train

parent e69034b5
@@ -75,7 +75,6 @@ class Model(ModelDesc):
             l = c2 + l
             return l
-

         l = conv('conv0', image, 16, 1)
         l = BatchNorm('bn0', l, is_training)
         l = tf.nn.relu(l)
@@ -113,7 +112,7 @@ class Model(ModelDesc):
         #wd_cost = regularize_cost('.*/W', l2_regularizer(0.0002), name='regularize_loss')
         wd_w = tf.train.exponential_decay(0.0001, get_global_step_var(),
                                           960000, 0.5, True)
-        wd_cost = wd_w * regularize_cost('.*/W', tf.nn.l2_loss)
+        wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='wd_cost')
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)

         add_param_summary([('.*/W', ['histogram'])])   # monitor W
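The replaced line computes the same product, but through `tf.mul` (the TF 0.x spelling of what later became `tf.multiply`) so the resulting tensor carries an explicit name; the moving-average summary registered on the next line then reports a readable `wd_cost` instead of an auto-generated op name. A minimal sketch of the difference, with placeholder constants:

```python
import tensorflow as tf

wd_w = tf.constant(1e-4)
reg = tf.constant(3.2)

anon = wd_w * reg                           # auto-named op, e.g. 'mul:0'
named = tf.mul(wd_w, reg, name='wd_cost')   # tensor named 'wd_cost:0'
                                            # (tf.multiply in TF >= 1.0)
```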
@@ -28,18 +28,28 @@ class Callback(object):
         Default is `TrainCallbackType()`
     """

-    def before_train(self, trainer):
+    def before_train(self):
         """
-        Called before starting iterative training.
+        Called right before the first iteration.
+        """
+        self._before_train()
+
+    def _before_train(self):
+        pass
+
+    def setup_graph(self, trainer):
+        """
+        Called before finalizing the graph.
+        Use this callback to setup some ops used in the callback.

         :param trainer: a :class:`train.Trainer` instance
         """
         self.trainer = trainer
         self.graph = tf.get_default_graph()
         self.epoch_num = self.trainer.config.starting_epoch
-        self._before_train()
+        self._setup_graph()

-    def _before_train(self):
+    def _setup_graph(self):
         pass

     def after_train(self):
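With this split, a user-defined callback overrides `_setup_graph` for anything that creates ops (the graph is still mutable there, and `self.trainer` / `self.graph` have just been set), and `_before_train` for work that only needs the live session. A minimal sketch of the new contract; the `StepCounter` callback and its variable are made up for illustration:

```python
import tensorflow as tf

class StepCounter(Callback):
    """Hypothetical callback, not part of this commit."""

    def _setup_graph(self):
        # Runs before the graph is finalized: safe to add ops here.
        self.counter = tf.Variable(0, name='my_counter', trainable=False)
        self.incr_op = self.counter.assign_add(1)

    def _before_train(self):
        # Runs right before the first iteration, after the graph is
        # finalized: only read or run existing ops here.
        start = self.trainer.sess.run(self.counter)
        print("starting with counter={}".format(start))
```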
@@ -23,7 +23,7 @@ class ModelSaver(Callback):
         self.keep_recent = keep_recent
         self.keep_freq = keep_freq

-    def _before_train(self):
+    def _setup_graph(self):
         self.path = os.path.join(logger.LOG_DIR, 'model')
         self.saver = tf.train.Saver(
             var_list=ModelSaver._get_vars(),
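`ModelSaver` moves its `tf.train.Saver` construction to the new hook for a concrete reason: creating a `Saver` adds save/restore ops to the graph, so it has to happen in `_setup_graph`, before the trainer calls `tf.get_default_graph().finalize()` (a finalized graph raises on any further op creation).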
@@ -66,7 +66,7 @@ class TestCallbackContext(object):
         self.sess = None

     @contextmanager
-    def before_train_context(self, trainer):
+    def create_context(self, trainer):
         if self.sess is None:
             with create_test_session(trainer) as sess:
                 self.sess = sess
@@ -88,7 +88,7 @@ class TestCallbackContext(object):
             self.saver.restore(self.sess, ckpt.model_checkpoint_path)

     @contextmanager
-    def trigger_epoch_context(self):
+    def test_context(self):
         with self.graph.as_default(), self.sess.as_default():
             yield
@@ -110,13 +110,21 @@ class Callbacks(Callback):
         self.cbs = cbs
         self.test_callback_context = TestCallbackContext()

+    def _setup_graph(self):
+        for cb in self.cbs:
+            if isinstance(cb.type, TrainCallbackType):
+                cb.setup_graph(self.trainer)
+            else:
+                with self.test_callback_context.create_context(self.trainer):
+                    cb.setup_graph(self.trainer)
+
     def _before_train(self):
         for cb in self.cbs:
             if isinstance(cb.type, TrainCallbackType):
-                cb.before_train(self.trainer)
+                cb.before_train()
             else:
-                with self.test_callback_context.before_train_context(self.trainer):
-                    cb.before_train(self.trainer)
+                with self.test_callback_context.test_context():
+                    cb.before_train()

     def _after_train(self):
         for cb in self.cbs:
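Routing `setup_graph` through `create_context` (renamed from `before_train_context`) matters because `TestCallbackContext` builds its session lazily: the first caller triggers `create_test_session`, which needs to run during graph setup, before the trainer finalizes the graph. After that, `test_context` (renamed from `trigger_epoch_context`) only re-enters the already-created graph and session, so `before_train` and `trigger_epoch` dispatch stays cheap.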
@@ -141,7 +149,7 @@ class Callbacks(Callback):
                     with tm.timed_callback('restore checkpoint'):
                         self.test_callback_context.restore_checkpoint()
                     test_sess_restored = True
-                with self.test_callback_context.trigger_epoch_context(), \
+                with self.test_callback_context.test_context(), \
                         tm.timed_callback(type(cb).__name__):
                     cb.trigger_epoch()
         tm.log()
@@ -30,7 +30,7 @@ class HyperParamSetter(Callback):
         self.shape = shape
         self.last_value = None

-    def _before_train(self):
+    def _setup_graph(self):
         all_vars = tf.all_variables()
         for v in all_vars:
             if v.name == self.var_name:
@@ -59,6 +59,12 @@ class HyperParamSetter(Callback):
         pass

     def _trigger_epoch(self):
+        self._set_param()
+
+    def _before_train(self):
+        self._set_param()
+
+    def _set_param(self):
         v = self.get_current_value()
         if v is not None:
             self.assign_op.eval(feed_dict={self.val_holder:v})
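Factoring the assignment into `_set_param`, called from both `_before_train` and `_trigger_epoch`, means the scheduled value is applied once before the first iteration rather than only at epoch boundaries, while the variable lookup stays in `_setup_graph`, where the graph is still being assembled. The resulting call order, sketched with the hooks from this diff (the subclass and its constructor arguments are hypothetical):

```python
setter = SomeHyperParamSetter('learning_rate:0')   # hypothetical subclass/args

setter.setup_graph(trainer)   # _setup_graph: locate the variable, build assign_op
setter.before_train()         # _set_param: push the initial value immediately
# ... one epoch of training ...
setter.trigger_epoch()        # _set_param: push the value for the next epoch
```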
@@ -70,8 +70,6 @@ class StatHolder(object):
     def _write_stat(self):
         tmp_filename = self.filename + '.tmp'
         with open(tmp_filename, 'w') as f:
-            import IPython;
-            IPython.embed(config=IPython.terminal.ipapp.load_default_config())
             json.dump(self.stat_history, f)
         os.rename(tmp_filename, self.filename)
@@ -81,7 +81,7 @@ class Trainer(object):
         self._init_summary()
         get_global_step_var()   # ensure there is such var, before finalizing the graph
         callbacks = self.config.callbacks
-        callbacks.before_train(self)
+        callbacks.setup_graph(self)
         self.config.session_init.init(self.sess)
         tf.get_default_graph().finalize()
         self._start_all_threads()
@@ -91,6 +91,7 @@ class Trainer(object):
         self.global_step = get_global_step()
         logger.info("Start training with global_step={}".format(self.global_step))

+        callbacks.before_train()
         for epoch in range(self.config.starting_epoch, self.config.max_epoch+1):
             with timed_operation(
                 'Epoch {}, global_step={}'.format(
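Put together, the trainer's startup sequence now separates graph construction from training startup. A condensed sketch of the resulting order of operations, using the method names from the hunks above (the loop body is elided):

```python
# Inside Trainer's training entry point, roughly:
callbacks.setup_graph(self)              # callbacks may still add ops here
self.config.session_init.init(self.sess)
tf.get_default_graph().finalize()        # no ops can be created after this
self._start_all_threads()

callbacks.before_train()                 # graph is frozen; first iteration is next
for epoch in range(self.config.starting_epoch, self.config.max_epoch + 1):
    ...                                  # run steps, then trigger epoch callbacks
```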