Commit 817a7ecf authored by Yuxin Wu


Do not trust trainable_variables created by Keras layers. Use `M.weights` instead and print warnings about unknown variables. (#748)
parent 7388f508
@@ -15,29 +15,42 @@ But some basic knowledge of how they work is useful:
Following the terminology in TensorFlow,
a __tower function__ is a callable that takes input tensors and adds __one replicate__ of the model to the graph.
Most types of neural-network training could be described with this concept.
The concept of tower is used mainly to support:
1. Data-parallel multi-GPU training, where a replicate is built on each GPU.
2. Graph construction for inference, where a replicate is built under inference mode.
A user needs to provide a tower function to use `TowerTrainer`.
In particular, when working with the `ModelDesc` interface, the `build_graph` method will be the tower function.
The tower function needs to follow some conventions:
1. It will always be called under a `TowerContext`,
   which will contain information about reuse, training/inference, scope name, etc.
2. __It might get called multiple times__ for data-parallel training or inference.
3. It has to respect variable collections:
   * Only put variables __trainable by gradient descent__ into `TRAINABLE_VARIABLES`.
   * Put variables that need to be saved into `MODEL_VARIABLES`.
4. It has to respect variable scopes:
   * The name of any trainable variable created in the function must be like "variable_scope_name/variable/name".
     Don't depend on name_scope's name. Don't use a variable_scope's name twice.
   * To respect variable reuse, use `tf.get_variable` instead of
     `tf.Variable` in the function, unless you want to force creation of new variables.
     For non-trainable variables, it's OK to use `tf.Variable` to force the creation of new variables in each tower.
These conventions are easy to follow, and most layer wrappers (e.g.,
tf.layers/slim/tensorlayer) do follow them. Note that certain Keras layers do not
follow these conventions and may crash if used within tensorpack.
It's possible to write training code that doesn't fit the tower concept, but all existing trainers in
tensorpack are subclasses of [TowerTrainer](../modules/train.html#tensorpack.train.TowerTrainer).
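
As an illustration, here is a minimal, hedged sketch of a tower function written through the `ModelDesc` interface. The class name, shapes, and layers are made up, and the exact `ModelDesc` method names can vary slightly across tensorpack versions:

```python
import tensorflow as tf
from tensorpack import ModelDesc

class MyModel(ModelDesc):
    def inputs(self):
        # Describes one datapoint; placeholders are created for each tower.
        return [tf.placeholder(tf.float32, [None, 28, 28, 1], 'image'),
                tf.placeholder(tf.int32, [None], 'label')]

    def build_graph(self, image, label):
        # This method is the tower function: it may run once per GPU, so it only
        # uses `tf.get_variable`-based layers (tf.layers/slim/tensorpack layers).
        logits = tf.layers.dense(tf.layers.flatten(image), 10, name='fc')
        cost = tf.losses.sparse_softmax_cross_entropy(labels=label, logits=logits)
        return tf.identity(cost, name='total_cost')

    def optimizer(self):
        return tf.train.AdamOptimizer(1e-3)
```

Because the sketch only creates variables through `tf.get_variable`-based layers, calling `build_graph` again under a reusing variable scope shares the weights, which is exactly what data-parallel replication expects.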
### MultiGPU Trainers
For data-parallel multi-GPU training, different [multi-GPU trainers](../modules/train.html)
implement different distribution strategies.
They take care of device placement, gradient averaging and synchronization
in an efficient way, and all reach the same performance as the
[official TF benchmarks](https://www.tensorflow.org/performance/benchmarks).
@@ -46,11 +59,11 @@ It takes only one line of code change to use them.
Note some __common problems__ when using these trainers:
1. In each iteration, all GPUs (all replicates of the model) take tensors from the `InputSource`,
instead of taking one for all and splitting it.
So the total batch size would become ``(batch size of InputSource) * #GPU``.
Splitting a tensor for data-parallel training makes no sense: it only puts unnecessary shape constraints on the data.
By letting each GPU train on its own input tensors, they can train on inputs of different shapes simultaneously.
2. The tower function (your model code) will get called multiple times.
As a result, you'll need to be careful when modifying global state in those functions, e.g. adding ops to TF collections.
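
For context, a hedged sketch of how such a trainer is typically launched; the class names follow the tensorpack API of this era, and `MyModel` / `my_dataflow` are hypothetical placeholders for a `ModelDesc` and an input pipeline:

```python
from tensorpack import TrainConfig, SyncMultiGPUTrainerReplicated, launch_train_with_config

# `my_dataflow` stands for any DataFlow yielding [image, label] batches of size 64.
config = TrainConfig(model=MyModel(), dataflow=my_dataflow)

# With 4 GPUs, each replicate pulls its own batch of 64 from the InputSource,
# so the effective total batch size is 64 * 4 = 256.
launch_train_with_config(config, SyncMultiGPUTrainerReplicated(4))
```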
@@ -32,3 +32,11 @@ It has:
+ Still slightly slower than native tensorpack examples.
+ Good accuracy (same as [tensorpack ResNet example](../ResNet))
### Note:
Keras support is __not official__. Keras does not use variable scopes or variable
collections, which conflicts with how tensorpack trainers work.
Therefore, not all Keras layers are supported in tensorpack.
These simple examples can run within tensorpack smoothly, but note that a future version
of Keras may still break them (unlikely, though).
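
To see why this matters, here is a minimal TF 1.x graph-mode sketch (the names are invented) of the variable sharing that tower replication relies on; Keras layers create and track their weights on the layer object instead, outside this mechanism:

```python
import tensorflow as tf

# Convention-following layers share weights across towers because they go
# through `tf.get_variable`, which honors variable_scope reuse:
with tf.variable_scope("tower_shared"):
    w1 = tf.get_variable("W", shape=[4, 4])
with tf.variable_scope("tower_shared", reuse=True):
    w2 = tf.get_variable("W", shape=[4, 4])
assert w1 is w2   # the same variable is returned, so replicas share weights

# Standalone Keras layers instead create their weights directly and keep them
# on the layer object, bypassing variable scopes and the TRAINABLE_VARIABLES
# collection -- which is why tensorpack has to special-case them (see below).
```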
@@ -16,6 +16,7 @@ from ..callbacks import (
ScalarStats)
from ..tfutils.common import get_op_tensor_name
from ..tfutils.collection import backup_collection, restore_collection
from ..tfutils.tower import get_current_tower_context
from ..tfutils.scope_utils import cached_name_scope
from ..tfutils.summary import add_moving_summary
@@ -37,8 +38,8 @@ def _check_name(tensor, name):
class KerasModelCaller(object):
"""
Keras model doesn't support vs reuse.
This is hack to mimic reuse.
Keras model doesn't support variable scope reuse.
This is a hack to mimic reuse.
"""
def __init__(self, get_model):
self.get_model = get_model
@@ -53,20 +54,39 @@ class KerasModelCaller(object):
            output tensors of this tower, evaluated with the input tensors.
        """
        reuse = tf.get_variable_scope().reuse
        old_trainable_names = set([x.name for x in tf.trainable_variables()])
        trainable_backup = backup_collection([tf.GraphKeys.TRAINABLE_VARIABLES])
        try:
            if self.cached_model is None:
                assert not reuse
                M = self.cached_model = self.get_model(*input_tensors)
                return M.outputs
            elif reuse:
                # use the cached Keras model to mimic reuse
                # NOTE: ctx.is_training won't be useful inside model,
                # because inference will always use the cached Keras model
                M = self.cached_model
                return M.call(input_tensors)
            else:
                # create new Keras model if not reuse
                M = self.get_model(*input_tensors)
                return M.outputs
        finally:
            added_trainable_names = set([x.name for x in tf.trainable_variables()])
            restore_collection(trainable_backup)
            for v in M.weights:
                # In Keras, the collection is not respected and could contain non-trainable vars.
                # We put M.weights into the collection instead.
                if v.name not in old_trainable_names:
                    tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, v)
            new_trainable_names = set([x.name for x in tf.trainable_variables()])
            for n in added_trainable_names:
                if n not in new_trainable_names:
                    logger.warn("Keras created trainable variable '{}' which is actually not trainable. "
                                "This was automatically corrected by tensorpack.".format(n))
# Keras needs an extra input if learning_phase is used by the model
......
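
For orientation, a hypothetical sketch (not part of this commit) of the kind of `get_model` callable that `KerasModelCaller` above wraps; the layers and names are invented, and `keras.layers.Input(tensor=...)` is assumed to be available in the Keras version in use:

```python
from tensorflow import keras

def get_model(image):
    # Build a Keras model directly on top of the tower's input tensor.
    inp = keras.layers.Input(tensor=image)
    x = keras.layers.Flatten()(inp)
    x = keras.layers.Dense(10, activation='softmax')(x)
    return keras.models.Model(inputs=inp, outputs=x)

model_caller = KerasModelCaller(get_model)
# The trainer invokes `model_caller` with each tower's input tensors: the first
# call builds and caches the Keras model; calls under a reusing variable scope
# reuse the cached model instead of creating new weights.
```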
@@ -275,26 +275,36 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
        # literally all variables, because it's better to sync optimizer-internal variables as well
        all_vars = tf.global_variables() + tf.local_variables()
        var_by_name = dict([(v.name, v) for v in all_vars])
        trainable_names = set([x.name for x in tf.trainable_variables()])
        post_init_ops = []

        def log_failure(name, reason):
            if name in trainable_names:
                msg = "This variable is trainable, so this is probably a fatal error."
            else:
                msg = "This variable is non-trainable. Ignore this warning if you know it's OK to leave it out-of-sync."
            logger.warn("[ReplicatedTrainer] Do not know how to sync variable '{}' across GPUs. "
                        "Reason: {} ".format(name, reason) + msg)

        for v in all_vars:
            if not v.name.startswith('tower'):
                continue
            if v.name.startswith('tower0'):
                # in this trainer, the master name doesn't have the towerx/ prefix
                log_failure(v.name, "Name should not have prefix 'tower0' in this trainer!")
                continue  # TODO some vars (EMA) may still startswith tower0
            split_name = v.name.split('/')
            prefix = split_name[0]
            realname = '/'.join(split_name[1:])
            if prefix in realname:
                log_failure(v.name, "Prefix {} appears multiple times in its name!".format(prefix))
                continue
            copy_from = var_by_name.get(realname)
            if copy_from is not None:
                post_init_ops.append(v.assign(copy_from.read_value()))
            else:
                log_failure(v.name, "Cannot find {} in the graph!".format(realname))
        logger.info(
            "'sync_variables_from_main_tower' includes {} operations.".format(len(post_init_ops)))
        return tf.group(*post_init_ops, name='sync_variables_from_main_tower')
......
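
To make the name handling above concrete, here is a small standalone sketch (with hypothetical variable names) of the per-tower to master-name mapping this sync step performs:

```python
def master_name(var_name):
    """Map a per-tower variable name to the master (tower-less) name, or None."""
    if not var_name.startswith('tower') or var_name.startswith('tower0'):
        return None   # skipped, or reported via log_failure in the real code
    split_name = var_name.split('/')
    prefix, realname = split_name[0], '/'.join(split_name[1:])
    if prefix in realname:
        return None   # e.g. 'tower1/tower1/W:0' would be ambiguous
    return realname

assert master_name('tower1/conv0/W:0') == 'conv0/W:0'   # synced from 'conv0/W:0'
assert master_name('global_step:0') is None             # not a per-tower variable
```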
@@ -20,7 +20,9 @@ _CurrentTowerContext = None
class TowerContext(object):
""" A context where the current model is being built in. """
""" A context where the current model is built in.
Since TF1.8, TensorFlow starts to introduce the same concept.
"""
def __init__(self, tower_name, is_training, index=0, vs_name=''):
"""
......
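
For reference, a minimal usage sketch of `TowerContext`; `tower_func` and `image` are hypothetical stand-ins for any tower function and input tensor:

```python
from tensorpack.tfutils.tower import TowerContext

# Build one replicate of the model in inference mode.
with TowerContext('', is_training=False):
    outputs = tower_func(image)
```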
File mode changed from 100644 to 100755