Commit 4aaf06ca authored by Yuxin Wu

separate setup_graph & before_train

parent e69034b5
@@ -75,7 +75,6 @@ class Model(ModelDesc):
                 l = c2 + l
                 return l
 
         l = conv('conv0', image, 16, 1)
-        l = BatchNorm('bn0', l, is_training)
         l = tf.nn.relu(l)
@@ -113,7 +112,7 @@ class Model(ModelDesc):
         #wd_cost = regularize_cost('.*/W', l2_regularizer(0.0002), name='regularize_loss')
         wd_w = tf.train.exponential_decay(0.0001, get_global_step_var(),
                                           960000, 0.5, True)
-        wd_cost = wd_w * regularize_cost('.*/W', tf.nn.l2_loss)
+        wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='wd_cost')
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
         add_param_summary([('.*/W', ['histogram'])])   # monitor W
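Besides giving the weight-decay term an explicit name via `tf.mul(..., name='wd_cost')` so the moving-summary machinery can report it by name, the coefficient itself follows a staircase exponential decay. A minimal plain-Python sketch (illustration only, not the library code) of the schedule that `tf.train.exponential_decay(0.0001, step, 960000, 0.5, True)` produces:

```python
# Sketch only: the staircase exponential-decay schedule used for
# the weight-decay coefficient above.
def decayed_wd(global_step, initial=0.0001, decay_steps=960000, decay_rate=0.5):
    # staircase=True means integer division: the coefficient halves
    # every 960k steps instead of decaying continuously
    return initial * decay_rate ** (global_step // decay_steps)

assert decayed_wd(0) == 0.0001
assert decayed_wd(960000) == 0.00005
```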
@@ -28,18 +28,28 @@ class Callback(object):
         Default is `TrainCallbackType()`
     """
 
-    def before_train(self, trainer):
+    def before_train(self):
         """
-        Called before starting iterative training.
+        Called right before the first iteration.
         """
+        self._before_train()
+
+    def _before_train(self):
+        pass
+
+    def setup_graph(self, trainer):
+        """
+        Called before finalizing the graph.
+        Use this callback to setup some ops used in the callback.
+        :param trainer: a :class:`train.Trainer` instance
+        """
         self.trainer = trainer
         self.graph = tf.get_default_graph()
         self.epoch_num = self.trainer.config.starting_epoch
-        self._before_train()
+        self._setup_graph()
 
-    def _before_train(self):
+    def _setup_graph(self):
         pass
 
     def after_train(self):
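The new lifecycle cleanly splits graph construction from run-time work: `setup_graph` runs before `tf.get_default_graph().finalize()` and is where a callback should create any ops it needs, while `before_train` runs right before the first iteration, when the graph is already frozen. A hypothetical subclass under the new API (the callback name and the `learning_rate:0` tensor are illustrative assumptions, not part of this commit):

```python
class DumpLearningRate(Callback):
    def _setup_graph(self):
        # self.graph was set by Callback.setup_graph; tensor lookups and
        # op creation belong here, before the graph is finalized
        self.lr = self.graph.get_tensor_by_name('learning_rate:0')  # assumed name

    def _before_train(self):
        # the graph is frozen by now: only run existing ops, don't create new ones
        logger.info('initial lr: {}'.format(self.lr.eval()))
```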
@@ -23,7 +23,7 @@ class ModelSaver(Callback):
         self.keep_recent = keep_recent
         self.keep_freq = keep_freq
 
-    def _before_train(self):
+    def _setup_graph(self):
         self.path = os.path.join(logger.LOG_DIR, 'model')
         self.saver = tf.train.Saver(
             var_list=ModelSaver._get_vars(),
@@ -66,7 +66,7 @@ class TestCallbackContext(object):
         self.sess = None
 
     @contextmanager
-    def before_train_context(self, trainer):
+    def create_context(self, trainer):
         if self.sess is None:
             with create_test_session(trainer) as sess:
                 self.sess = sess
@@ -88,7 +88,7 @@ class TestCallbackContext(object):
         self.saver.restore(self.sess, ckpt.model_checkpoint_path)
 
     @contextmanager
-    def trigger_epoch_context(self):
+    def test_context(self):
         with self.graph.as_default(), self.sess.as_default():
             yield
@@ -110,13 +110,21 @@ class Callbacks(Callback):
         self.cbs = cbs
         self.test_callback_context = TestCallbackContext()
 
+    def _setup_graph(self):
+        for cb in self.cbs:
+            if isinstance(cb.type, TrainCallbackType):
+                cb.setup_graph(self.trainer)
+            else:
+                with self.test_callback_context.create_context(self.trainer):
+                    cb.setup_graph(self.trainer)
+
     def _before_train(self):
         for cb in self.cbs:
             if isinstance(cb.type, TrainCallbackType):
-                cb.before_train(self.trainer)
+                cb.before_train()
             else:
-                with self.test_callback_context.before_train_context(self.trainer):
-                    cb.before_train(self.trainer)
+                with self.test_callback_context.test_context():
+                    cb.before_train()
 
     def _after_train(self):
         for cb in self.cbs:
@@ -141,7 +149,7 @@ class Callbacks(Callback):
                         with tm.timed_callback('restore checkpoint'):
                             self.test_callback_context.restore_checkpoint()
                         test_sess_restored = True
-                with self.test_callback_context.trigger_epoch_context(), \
+                with self.test_callback_context.test_context(), \
                         tm.timed_callback(type(cb).__name__):
                     cb.trigger_epoch()
         tm.log()
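The dispatch in `Callbacks` is the same for both phases: train-type callbacks run directly in the training context, everything else runs inside `TestCallbackContext`, whose session is created lazily on first use and reused afterwards. A self-contained sketch of that pattern (simplified names, not the library code):

```python
from contextlib import contextmanager

class LazyTestContext(object):
    """Stands in for TestCallbackContext: expensive setup happens only once."""
    def __init__(self):
        self.resource = None

    @contextmanager
    def context(self):
        if self.resource is None:
            self.resource = object()   # e.g. a separate test session
        yield self.resource

def dispatch(callbacks, method_name):
    ctx = LazyTestContext()
    for cb, is_train_cb in callbacks:
        if is_train_cb:
            getattr(cb, method_name)()   # run in the training context
        else:
            with ctx.context():          # run in the shared test context
                getattr(cb, method_name)()
```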
@@ -30,7 +30,7 @@ class HyperParamSetter(Callback):
         self.shape = shape
         self.last_value = None
 
-    def _before_train(self):
+    def _setup_graph(self):
         all_vars = tf.all_variables()
         for v in all_vars:
             if v.name == self.var_name:
@@ -59,6 +59,12 @@ class HyperParamSetter(Callback):
         pass
 
     def _trigger_epoch(self):
+        self._set_param()
+
+    def _before_train(self):
+        self._set_param()
+
+    def _set_param(self):
         v = self.get_current_value()
         if v is not None:
             self.assign_op.eval(feed_dict={self.val_holder:v})
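With this change the hyperparameter is pushed to the graph both right before training starts and at every epoch boundary, through the shared `_set_param`. The underlying mechanism is the usual placeholder + assign-op pattern; a minimal sketch using the TF 0.x-era API this codebase targets (variable name and values are illustrative):

```python
import tensorflow as tf

# a non-trainable variable holding the hyperparameter
lr = tf.Variable(0.1, name='learning_rate', trainable=False)
val_holder = tf.placeholder(tf.float32, shape=[])
assign_op = lr.assign(val_holder)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())   # TF 0.x initializer
    # same call shape as self.assign_op.eval(feed_dict={self.val_holder: v})
    assign_op.eval(feed_dict={val_holder: 0.01})
    print(lr.eval())   # 0.01
```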
@@ -70,8 +70,6 @@ class StatHolder(object):
     def _write_stat(self):
         tmp_filename = self.filename + '.tmp'
         with open(tmp_filename, 'w') as f:
-            import IPython;
-            IPython.embed(config=IPython.terminal.ipapp.load_default_config())
             json.dump(self.stat_history, f)
         os.rename(tmp_filename, self.filename)
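Beyond dropping the leftover `IPython.embed` debugging call, this hunk shows the write-to-temp-then-rename idiom: readers never observe a half-written stats file, because `os.rename` replaces the destination atomically on POSIX when both paths are on the same filesystem. The pattern in isolation:

```python
import json
import os

def write_stat_atomically(filename, stats):
    tmp = filename + '.tmp'
    with open(tmp, 'w') as f:
        json.dump(stats, f)
    os.rename(tmp, filename)   # atomic replace on POSIX, same filesystem
```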
@@ -81,7 +81,7 @@ class Trainer(object):
         self._init_summary()
         get_global_step_var()   # ensure there is such var, before finalizing the graph
         callbacks = self.config.callbacks
-        callbacks.before_train(self)
+        callbacks.setup_graph(self)
         self.config.session_init.init(self.sess)
         tf.get_default_graph().finalize()
         self._start_all_threads()
@@ -91,6 +91,7 @@ class Trainer(object):
         self.global_step = get_global_step()
         logger.info("Start training with global_step={}".format(self.global_step))
 
+        callbacks.before_train()
         for epoch in range(self.config.starting_epoch, self.config.max_epoch+1):
             with timed_operation(
                 'Epoch {}, global_step={}'.format(
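Taken together, the two trainer hunks fix the startup order so that callbacks can still build ops before the graph is frozen. A pseudocode outline of the sequence after this commit (details elided):

```python
# callbacks.setup_graph(self)           # callbacks may still create ops
# self.config.session_init.init(sess)   # initialize / restore variables
# tf.get_default_graph().finalize()     # graph becomes immutable
# self._start_all_threads()
# callbacks.before_train()              # right before the first iteration
# for epoch in range(starting_epoch, max_epoch + 1):
#     ...                               # run steps; trigger_epoch on callbacks
```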