Commit a2c36b3d authored by Yuxin Wu

misc docs update; use virtual_batch_size only for TF>=1.5 (fix #737)

parent 03f18976
......
@@ -41,7 +41,9 @@ Accuracy:
     With (W,A,G)=(1,2,6), 47.6% error
     With (W,A,G)=(1,2,4), 58.4% error
-    Don't train with >4 GPUs because the batch size will be different.
+    Training with 2 or 8 GPUs is supported but the result may get slightly
+    different, due to limited per-GPU batch size.
+    You may want to adjust total batch size and learning rate accordingly.
 Speed:
     About 11 iteration/s on 4 P100s. (Each epoch is set to 10000 iterations)
......
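For the "adjust total batch size and learning rate accordingly" advice above, a common heuristic is the linear scaling rule: scale the learning rate in proportion to the total batch size. A hypothetical helper, not part of the example; the reference constants are assumptions:

# Linear scaling rule (sketch): keep lr / total_batch_size roughly constant.
BASE_TOTAL_BATCH = 256   # assumed reference: e.g. 64 per GPU x 4 GPUs
BASE_LR = 1e-4           # assumed reference learning rate

def scaled_lr(num_gpus, per_gpu_batch=64):
    """Return a learning rate scaled linearly with the total batch size."""
    total_batch = num_gpus * per_gpu_batch
    return BASE_LR * total_batch / BASE_TOTAL_BATCH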
......
@@ -15,7 +15,7 @@ A small convnet model for Cifar10 or Cifar100 dataset.
 Cifar10 trained on 1 GPU:
     91% accuracy after 50k iterations.
-    70 itr/s on P100
+    79 itr/s on P100
 Not a good model for Cifar100, just for demonstration.
 """
......
......
@@ -89,8 +89,9 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
     if training is None:
         training = ctx.is_training
     training = bool(training)
+    TF_version = get_tf_version_number()
     if not training and ctx.is_training:
-        assert get_tf_version_number() >= 1.4, \
+        assert TF_version >= 1.4, \
             "Fine tuning a BatchNorm model with fixed statistics is only " \
             "supported after https://github.com/tensorflow/tensorflow/pull/12580 "
         if ctx.is_main_training_tower:  # only warn in first tower
......
@@ -102,15 +103,26 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
     with rename_get_variable(
             {'moving_mean': 'mean/EMA',
              'moving_variance': 'variance/EMA'}):
-        layer = tf.layers.BatchNormalization(
-            axis=axis,
-            momentum=momentum, epsilon=epsilon,
-            center=center, scale=scale,
-            beta_initializer=beta_initializer,
-            gamma_initializer=gamma_initializer,
-            virtual_batch_size=virtual_batch_size,
-            fused=True
-        )
+        if TF_version >= 1.5:
+            layer = tf.layers.BatchNormalization(
+                axis=axis,
+                momentum=momentum, epsilon=epsilon,
+                center=center, scale=scale,
+                beta_initializer=beta_initializer,
+                gamma_initializer=gamma_initializer,
+                virtual_batch_size=virtual_batch_size,
+                fused=True
+            )
+        else:
+            assert virtual_batch_size is None, "Feature not supported in this version of TF!"
+            layer = tf.layers.BatchNormalization(
+                axis=axis,
+                momentum=momentum, epsilon=epsilon,
+                center=center, scale=scale,
+                beta_initializer=beta_initializer,
+                gamma_initializer=gamma_initializer,
+                fused=True
+            )
     xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())
     # maintain EMA only on one GPU is OK, even in replicated mode.
......
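The two constructor calls in the new branch differ only in whether virtual_batch_size is passed, because that keyword argument was added to tf.layers.BatchNormalization in TF 1.5. A minimal sketch of the same gate, for illustration only: make_bn_layer and the tf.__version__ parsing below are not tensorpack code (tensorpack uses get_tf_version_number(), as in the diff), and fused=True is omitted to keep the sketch minimal.

import tensorflow as tf

def make_bn_layer(axis, momentum=0.9, epsilon=1e-5, center=True, scale=True,
                  virtual_batch_size=None):
    # Parse e.g. "1.5.0" -> (1, 5); a stand-in for get_tf_version_number().
    tf_version = tuple(int(x) for x in tf.__version__.split('.')[:2])
    kwargs = dict(axis=axis, momentum=momentum, epsilon=epsilon,
                  center=center, scale=scale)
    if tf_version >= (1, 5):
        # The kwarg exists only in TF>=1.5; passing None there is harmless.
        kwargs['virtual_batch_size'] = virtual_batch_size
    else:
        assert virtual_batch_size is None, \
            "virtual_batch_size requires TF>=1.5"
    return tf.layers.BatchNormalization(**kwargs)

virtual_batch_size implements "ghost batch normalization": when set, statistics are computed per virtual sub-batch of that size rather than over the whole input batch.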
......
@@ -6,8 +6,9 @@ GIT_ARG="--git-dir ../.git --work-tree .."
 # find out modified python files, so that we ignored unstaged files
 # exclude ../docs
-MOD=$(git $GIT_ARG status -s | grep -E '\.py$' \
-    | grep -E '^\b+M\b+|^A' | cut -c 4- | grep -v '../docs')
+MOD=$(git $GIT_ARG status -s \
+    | grep -E '\.py$' | grep -v '../docs' \
+    | grep -E '^ *M|^ *A' | cut -c 4- )
 if [[ -n $MOD ]]; then
     flake8 $MOD
 fi
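Why the grep change: \b in the old pattern is a zero-width word-boundary assertion, and there is no word boundary before the leading space of a " M" line in `git status -s` output, so files modified only in the worktree were never selected; '^ *M|^ *A' matches the two-character status column directly. A small illustration in Python (the sample lines below are made up):

import re

# Hypothetical `git status -s` lines: columns 1-2 are the status, path starts at column 4.
sample = [
    " M tensorpack/models/batch_norm.py",   # modified, unstaged -> should match
    "A  some/new_file.py",                  # added to the index -> should match
    "?? scratch_notes.txt",                 # untracked -> should not match
]
for line in sample:
    if re.match(r'^ *M|^ *A', line):
        print(line[3:])   # equivalent of `cut -c 4-`: strip the status column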
......
@@ -4,6 +4,7 @@ ignore = E265,E741,E742,E743
 exclude = .git,
           __init__.py,
           setup.py,
+          tensorpack/train/eager.py,
           docs,
           examples,
           docs/conf.py
......