Commit cb99d524 authored by Patrick Wieschollek, committed by Yuxin Wu

Add functions to simplify the usage of tfSlim (#81)

This fetches all regularization terms when tfSlim is used.
Batch-norm updates are also automatically performed on the main tower,
similar to the current implementation for the built-in batch-norm
layer.
parent d50341b8
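
For context, a minimal sketch of the tfSlim usage pattern this change supports (assuming the TF 1.x-era tf.contrib.slim API; the layer names and sizes are illustrative, not taken from this commit):

    import tensorflow as tf
    slim = tf.contrib.slim

    def build_graph(image, labels):
        # weights_regularizer makes every layer in the arg_scope append its
        # (already multiplied) L2 term to tf.GraphKeys.REGULARIZATION_LOSSES,
        # which the new get_cost() below adds to the returned cost
        with slim.arg_scope([slim.conv2d, slim.fully_connected],
                            weights_regularizer=slim.l2_regularizer(0.001)):
            net = slim.conv2d(image, 32, [3, 3], scope='conv1')
            # slim.batch_norm registers its moving-average updates in
            # tf.GraphKeys.UPDATE_OPS; get_cost() attaches them on the main tower
            net = slim.batch_norm(net, is_training=True)
            net = slim.flatten(net)
            logits = slim.fully_connected(net, 10, activation_fn=None, scope='fc')
        return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels))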
@@ -113,9 +113,44 @@ Use _build_graph(self, input_vars) and get_current_tower_context().is_training i
     def get_cost(self):
         """
         Return the cost tensor in the graph. Called by some of the :class:`tensorpack.train.Trainer` which
-        assumes single-cost models.
-        """
-        return self._get_cost()
+        assumes single-cost models. Also applies tfSlim modifications.
+        """
+        # name of the current variable scope, used to filter the collections below
+        scope = tf.get_variable_scope().name
+        # the model cost so far
+        cost = self._get_cost()
+
+        # In contrast to this library, tfSlim users expect
+        # "with slim.arg_scope([...], weights_regularizer=slim.l2_regularizer(0.001))"
+        # to regularize these layers automatically. Note that the collected
+        # losses already contain the multiplier!
+        regularization_losses = set()
+        # guard against the regex error some scope names can trigger
+        # (tf.get_collection() uses the scope as a regex pattern)
+        try:
+            regularization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope=scope))
+        except Exception:
+            regularization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+        # TODO: check if "scope=scope" should be used here too
+        if regularization_losses:
+            cost += tf.add_n(list(regularization_losses), name="regularize_loss")
+
+        # As these batch-norm statistics quickly accumulate, there is no
+        # significant loss of accuracy if only the main tower handles all
+        # batch-normalization updates, which are then shared across the towers
+        ctx = get_current_tower_context()
+        if ctx is not None and ctx.is_main_training_tower:
+            # same guard as above, for the UPDATE_OPS collection
+            try:
+                non_grad_updates = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=scope))
+            except Exception:
+                non_grad_updates = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
+            if non_grad_updates:
+                # make evaluating the cost also run the batch-norm updates
+                with tf.control_dependencies(non_grad_updates):
+                    cost = tf.identity(cost, name='cost_with_update')
+        return cost

     def _get_cost(self, *args):
         return self.cost
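To see what the scope filter in get_cost() operates on, here is a quick check of the regularization-loss collection (a sketch; the tower scope name is made up):

    import tensorflow as tf
    slim = tf.contrib.slim

    x = tf.placeholder(tf.float32, [None, 8])
    with tf.variable_scope('tower0'):
        with slim.arg_scope([slim.fully_connected],
                            weights_regularizer=slim.l2_regularizer(1e-3)):
            y = slim.fully_connected(x, 4, scope='fc')
    # one scaled L2 tensor per regularized layer, filtered by the scope prefix
    print(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope='tower0'))
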
@@ -175,6 +175,8 @@ class Trainer(object):
                 self.trigger_epoch()
         except StopTraining:
             logger.info("Training was stopped.")
+        except KeyboardInterrupt:
+            logger.info("Detected Ctrl+C, shutting down training.")
         except:
             raise
         finally: