Commit 3facd518 authored by Yuxin Wu's avatar Yuxin Wu

some small fix

parent 5dee7231
......@@ -8,6 +8,7 @@ More results to come.
| Model | Top 5 Error | Top 1 Error |
|:-------------------|-------------|------------:|
| ResNet 18 | 10.67% | 29.50% |
| ResNet 34 | 8.66% | 26.45% |
| ResNet 50 | 7.13% | 24.12% |
## load-resnet.py
......
......@@ -250,7 +250,7 @@ def main():
fin.close()
# parse column format
nr_column = len(all_inputs[0].rstrip().split())
nr_column = len(all_inputs[0].rstrip().split(args.delimeter))
if args.column is None:
column = ['y'] * nr_column
else:
......
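The fix above passes the user-specified delimiter to split() when counting columns ('delimeter' mirrors the script's own spelling). A minimal sketch, outside the commit, of why whitespace splitting miscounts delimited input:

line = "1.0,2.0,3.0"
print(len(line.rstrip().split()))      # 1 -- whitespace split sees one field
print(len(line.rstrip().split(",")))   # 3 -- the actual number of columns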
......@@ -2,6 +2,7 @@
# File: prefetch.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from __future__ import print_function
import multiprocessing as mp
from threading import Thread
import itertools
......@@ -158,16 +159,12 @@ class PrefetchDataZMQ(ProxyDataFlow):
    def __del__(self):
        # on exit, logger may not be functional anymore
        try:
            logger.info("Prefetch process exiting...")
        except:
            pass
        if not self.context.closed:
            self.context.destroy(0)
        for x in self.procs:
            x.terminate()
        try:
            logger.info("Prefetch process exited.")
            print("Prefetch process exited.")
        except:
            pass
......
......@@ -4,16 +4,16 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from copy import copy
import re
from ..tfutils.tower import get_current_tower_context
from ..utils import logger, EXTRA_SAVE_VARS_KEY
from ..utils import logger
from ._common import layer_register
__all__ = ['BatchNorm']
# http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
# TF batch_norm only works for 4D tensors right now: #804
# decay: being too close to 1 leads to slow start-up. torch uses 0.9.
# eps: torch: 1e-5. Lasagne: 1e-4
......@@ -25,15 +25,10 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
`Batch Normalization: Accelerating Deep Network Training by
Reducing Internal Covariate Shift <http://arxiv.org/abs/1502.03167>`_.
Notes:
* Whole-population mean/variance is estimated by a running average of the batch mean/variance.
* Epsilon for variance is set to 1e-5, as in `torch/nn <https://github.com/torch/nn/blob/master/BatchNormalization.lua>`_.
:param input: an NHWC or NC tensor
:param use_local_stat: bool. whether to use mean/var of this batch or the moving average.
Default to True in training and False in predicting.
:param decay: decay rate. default to 0.999.
Default to True in training and False in inference.
:param decay: decay rate. default to 0.9.
:param epsilon: default to 1e-5.
"""
......@@ -70,8 +65,8 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
    ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
    if ctx.is_main_training_tower:
        # inside main training tower
        tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_mean)
        tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_var)
        add_model_variable(ema_mean)
        add_model_variable(ema_var)
    else:
        if ctx.is_main_tower:
            # not training, but main tower. need to create the vars
......@@ -98,7 +93,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
        mul = tf.select(tf.equal(batch, 1.0), 1.0, batch / (batch - 1))
        batch_var = batch_var * mul  # use unbiased variance estimator in training
        return tf.nn.batch_normalization(
            x, batch_mean, batch_var, beta, gamma, epsilon, 'bn')
            x, batch_mean, batch_var, beta, gamma, epsilon, 'output')
    else:
        return tf.nn.batch_normalization(
            x, ema_mean, ema_var, beta, gamma, epsilon, 'bn')
            x, ema_mean, ema_var, beta, gamma, epsilon, 'output')
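A minimal numpy sketch (not the library code) of what the training branch computes: normalize with the batch statistics after scaling the biased batch variance by n/(n-1), as the `mul` line above does:

import numpy as np

def batch_norm_train(x, beta, gamma, epsilon=1e-5):
    # x: (batch, channels); beta, gamma: (channels,)
    n = float(x.shape[0])
    mean = x.mean(axis=0)
    var = x.var(axis=0)                    # biased estimate (divides by n)
    if n > 1:
        var = var * n / (n - 1.0)          # unbiased estimator, as in the hunk above
    return gamma * (x - mean) / np.sqrt(var + epsilon) + beta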
......@@ -107,9 +107,9 @@ class CheckGradient(MapGradient):
        super(CheckGradient, self).__init__(self._mapper)

    def _mapper(self, grad, var):
        # this is very slow...
        # this is very slow.... see #3649
        #op = tf.Assert(tf.reduce_all(tf.is_finite(var)), [var], summarize=100)
        grad = tf.check_numerics(grad, 'CheckGradient')
        grad = tf.check_numerics(grad, 'CheckGradient-' + var.op.name)
        return grad

class ScaleGradient(MapGradient):
......
......@@ -76,7 +76,9 @@ def dump_session_params(path):
    npy format, loadable by ParamRestore
    """
    var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    var.extend(tf.get_collection(EXTRA_SAVE_VARS_KEY))
    var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
    # TODO dedup
    assert len(set(var)) == len(var), "TRAINABLE and MODEL variables have duplication!"
    result = {}
    for v in var:
        name = get_savename_from_varname(v.name)
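Should a variable ever appear in both collections, the "# TODO dedup" above could be resolved with an order-preserving merge instead of the assert; a hedged sketch (the helper name is made up):

import tensorflow as tf

def collect_save_vars():
    var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
    seen, unique = set(), []
    for v in var:
        if v not in seen:      # keep first occurrence, drop duplicates
            seen.add(v)
            unique.append(v)
    return unique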
......@@ -102,7 +104,7 @@ def is_training_name(name):
    This is only used to improve logging.
    :returns: guess whether this tensor is something only used in training.
    """
    # TODO: maybe simply check against TRAINABLE_VARIABLES and EXTRA_SAVE_VARS_KEY ?
    # TODO: maybe simply check against TRAINABLE_VARIABLES and MODEL_VARIABLES?
    # TODO or use get_slot_names()
    name = get_op_tensor_name(name)[0]
    if name.endswith('/Adam') or name.endswith('/Adam_1'):
......
......@@ -11,9 +11,6 @@ MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
# placeholders for input variables
INPUT_VARS_KEY = 'INPUT_VARIABLES'
# variables that need to be saved for inference, apart from trainable variables
EXTRA_SAVE_VARS_KEY = 'EXTRA_SAVE_VARIABLES'
import tensorflow as tf
SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
......