Commit b5be9481 authored by Yuxin Wu

improve logger message; add "skip_collection" in freeze_variables (fix #664)

parent 92ccfe0a
@@ -31,11 +31,13 @@ Then it is a good time to open an issue.
## How to freeze some variables in training

1. Learn `tf.stop_gradient`. You can simply use `tf.stop_gradient` in your model code in many situations (e.g. to freeze the first several layers; see the sketch after this list).
2. [varreplace.freeze_variables](../modules/tfutils.html#tensorpack.tfutils.varreplace.freeze_variables) returns a context where variables are frozen.
   Learn to use the `custom_getter` argument of `tf.variable_scope` to gain more control over what & how variables are frozen.
3. [ScaleGradient](../modules/tfutils.html#tensorpack.tfutils.gradproc.ScaleGradient) can be used to set the gradients of some variables to 0.
   But it may be slow, since the variables still have gradients.

Note that the above methods only prevent variables from being updated by SGD.
Some variables may be updated by other means,
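A minimal sketch illustrating options 1 and 2 above (TF1-style graph code; the layer names and the tiny graph are made up for illustration):

```python
import tensorflow as tf
from tensorpack.tfutils import varreplace

def build_graph(image):
    # Option 1: stop gradients after the first layer, so 'conv0' is frozen.
    x = tf.layers.conv2d(image, 32, 3, name='conv0')
    x = tf.stop_gradient(x)

    # Option 2: variables created inside this context are frozen.
    with varreplace.freeze_variables():
        x = tf.layers.conv2d(x, 64, 3, name='conv1')   # conv1/* will not be trained

    x = tf.layers.conv2d(x, 10, 3, name='conv2')       # conv2/* trains normally
    return x
```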
......
@@ -38,23 +38,36 @@ def remap_variables(fn):
    return custom_getter_scope(custom_getter)


def freeze_variables(stop_gradient=True, skip_collection=False):
    """
    Return a context to freeze variables,
    by wrapping ``tf.get_variable`` with a custom getter.
    It works by either applying ``tf.stop_gradient`` on the variables,
    or by keeping them out of the ``TRAINABLE_VARIABLES`` collection, or
    both.

    Example:
        .. code-block:: python

            with varreplace.freeze_variables(stop_gradient=False, skip_collection=True):
                x = FullyConnected('fc', x, 1000)   # fc/* will not be trained

    Args:
        stop_gradient (bool): if True, variables returned from ``get_variable``
            will be wrapped with ``tf.stop_gradient`` and therefore have no
            gradient when used later. Note that the created variables may
            still have gradients when accessed by other approaches (e.g.
            by name, or by collection).
        skip_collection (bool): if True, do not add the variable to the
            ``TRAINABLE_VARIABLES`` collection. As a result, they will not be
            trained by default.
    """
    def custom_getter(getter, *args, **kwargs):
        trainable = kwargs.get('trainable', True)
        if skip_collection:
            kwargs['trainable'] = False
        v = getter(*args, **kwargs)
        if trainable and stop_gradient:
            v = tf.stop_gradient(v)
        return v
    return custom_getter_scope(custom_getter)
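For reference, the `custom_getter` mechanism used here can also be applied directly through `tf.variable_scope`, as the FAQ entry above suggests. A minimal sketch (TF1-style API; the getter name is chosen here for illustration):

```python
import tensorflow as tf

def stop_gradient_getter(getter, *args, **kwargs):
    # Create the variable as usual, then hide its gradient from any op
    # built on top of it (similar to freeze_variables(stop_gradient=True)).
    v = getter(*args, **kwargs)
    return tf.stop_gradient(v) if kwargs.get('trainable', True) else v

with tf.variable_scope('frozen', custom_getter=stop_gradient_getter):
    w = tf.get_variable('w', shape=[3, 3])   # usable in the graph, but receives no gradient
```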
@@ -82,6 +82,14 @@ def set_logger_dir(dirname, action=None):
        dirname(str): log directory
        action(str): an action of ("k","d","q") to be performed
            when the directory exists. Will ask user by default.

            "d": delete the directory. Note that the deletion may fail when
                the directory is used by tensorboard.
            "k": keep the directory. This is useful when you resume from a
                previous training and want the directory to look as if the
                training was not interrupted.
                Note that this option does not load old models or any other
                old states for you. It simply does nothing.
    """
    global LOG_DIR, _FILE_HANDLER
    if _FILE_HANDLER:
@@ -91,12 +99,12 @@ def set_logger_dir(dirname, action=None):
    if os.path.isdir(dirname) and len(os.listdir(dirname)):
        if not action:
            _logger.warn("""\
Log directory {} exists! Use 'd' to delete it. """.format(dirname))
            _logger.warn("""\
If you're resuming from a previous run, you can choose to keep it.
Press any other key to exit. """)
        while not action:
            action = input("Select Action: k (keep) / d (delete) / q (quit):").lower().strip()
        act = action
        if act == 'b':
            backup_name = dirname + _get_time_str()
@@ -109,10 +117,8 @@ If you're resuming from a previous run, you can choose to keep it.
            info("Use a new log directory {}".format(dirname))  # noqa: F821
        elif act == 'k':
            pass
        else:
            raise OSError("Directory {} exists!".format(dirname))
    LOG_DIR = dirname
    from .fs import mkdir_p
    mkdir_p(dirname)
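A small usage sketch of the options documented above (the directory name is arbitrary; `set_logger_dir` lives in `tensorpack.utils.logger`):

```python
from tensorpack.utils import logger

# 'k' keeps an existing directory, e.g. when resuming an interrupted run.
# This only affects logging; it does not restore models or any other state.
logger.set_logger_dir('train_log/my_experiment', action='k')
```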
......