Commit a041f5a9 authored by Yuxin Wu

example docs and improvements

parent 64bda846
......@@ -41,7 +41,7 @@ Multi-GPU training is off-the-shelf by simply switching the trainer.
## Dependencies:
+ Python 2 or 3
+ TensorFlow >= 0.8
+ TensorFlow >= 0.10
+ Python bindings for OpenCV
+ other requirements:
```
......
......@@ -27,6 +27,7 @@ To start training:
```bash
./hed.py --load vgg16.npy
```
It takes about 100k steps (~10 hours on a TitanX) to reach reasonable performance.
To run inference (produces a heatmap for each level at out*.png):
```bash
......
......@@ -21,8 +21,8 @@ See "Rethinking the Inception Architecture for Computer Vision", arxiv:1512.0056
This config follows the official inceptionv3 setup (https://github.com/tensorflow/models/tree/master/inception/inception)
with far fewer lines of code.
It reaches 74.5% single-crop validation accuracy, slightly better than the official code,
and has the same running speed as well.
It reaches 74% single-crop validation accuracy,
and has the same running speed as the official code.
The hyperparameters here are for 8 GPUs, so the effective batch size is 8*64 = 512.
With 8 TitanX it runs about 0.45 it/s.
"""
......
......@@ -51,11 +51,9 @@ class Model(ModelDesc):
.MaxPooling('pool2', 3, stride=2, padding='SAME') \
.Conv2D('conv3.1', out_channel=128, padding='VALID') \
.Conv2D('conv3.2', out_channel=128, padding='VALID') \
.FullyConnected('fc0', 1024 + 512, nl=tf.nn.relu,
b_init=tf.constant_initializer(0.1)) \
.FullyConnected('fc0', 1024 + 512, nl=tf.nn.relu) \
.tf.nn.dropout(keep_prob) \
.FullyConnected('fc1', 512, nl=tf.nn.relu,
b_init=tf.constant_initializer(0.1)) \
.FullyConnected('fc1', 512, nl=tf.nn.relu) \
.FullyConnected('linear', out_dim=self.cifar_classnum, nl=tf.identity)()
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
......
......@@ -14,9 +14,10 @@ from tensorpack.tfutils.summary import *
"""
A very small SVHN convnet model (only 0.8m parameters).
About 3.0% validation error after 70 epoch. 2.5% after 130 epoch.
About 2.3% validation error after 70 epochs. 2.15% after 150 epochs.
Each epoch is set to 4721 iterations. The speed is about 44 it/s on a Tesla M40
Each epoch iterates over the whole training set (4721 iterations).
Speed is about 43 it/s on TitanX.
"""
class Model(ModelDesc):
......@@ -29,7 +30,7 @@ class Model(ModelDesc):
image = image / 128.0 - 1
with argscope(Conv2D, nl=tf.nn.relu):
with argscope(Conv2D, nl=BNReLU, use_bias=False):
logits = (LinearWrap(image)
.Conv2D('conv1', 24, 5, padding='VALID')
.MaxPooling('pool1', 2, padding='SAME')
......@@ -62,7 +63,7 @@ def get_data():
d1 = dataset.SVHNDigit('train')
d2 = dataset.SVHNDigit('extra')
data_train = RandomMixData([d1, d2])
data_test = dataset.SVHNDigit('test')
data_test = dataset.SVHNDigit('test', shuffle=False)
augmentors = [
imgaug.Resize((40, 40)),
......
......@@ -24,7 +24,7 @@ with tf.Graph().as_default() as G:
if args.config:
MODEL = imp.load_source('config_script', args.config).Model
M = MODEL()
M.build_graph(M.get_input_vars(), is_training=False)
M.build_graph(M.get_input_vars())
else:
M = ModelFromMetaGraph(args.meta)
......
......@@ -60,8 +60,7 @@ def Conv2D(x, out_channel, kernel_shape,
for i, k in zip(inputs, kernels)]
conv = tf.concat(3, outputs)
if nl is None:
logger.warn("[DEPRECATED] Default nonlinearity for Conv2D and FullyConnected will be deprecated.")
logger.warn("[DEPRECATED] Please use argscope instead.")
logger.warn("[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
nl = tf.nn.relu
return nl(tf.nn.bias_add(conv, b) if use_bias else conv, name='output')
......@@ -40,7 +40,6 @@ def FullyConnected(x, out_dim,
b = tf.get_variable('b', [out_dim], initializer=b_init)
prod = tf.nn.xw_plus_b(x, W, b) if use_bias else tf.matmul(x, W)
if nl is None:
logger.warn("[DEPRECATED] Default nonlinearity for Conv2D and FullyConnected will be deprecated.")
logger.warn("[DEPRECATED] Please use argscope instead.")
logger.warn("[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
nl = tf.nn.relu
return nl(prod, name='output')
......@@ -50,13 +50,25 @@ class TrainConfig(object):
assert_type(self.session_config, tf.ConfigProto)
self.session_init = kwargs.pop('session_init', JustCurrentSession())
assert_type(self.session_init, SessionInit)
self.step_per_epoch = int(kwargs.pop('step_per_epoch'))
self.step_per_epoch = kwargs.pop('step_per_epoch', None)
if self.step_per_epoch is None:
try:
self.step_per_epoch = self.dataset.size()
except NotImplementedError:
logger.exception("You must set `step_per_epoch` if dataset.size() is not implemented.")
else:
self.step_per_epoch = int(self.step_per_epoch)
self.starting_epoch = int(kwargs.pop('starting_epoch', 1))
self.max_epoch = int(kwargs.pop('max_epoch', 99999))
assert self.step_per_epoch >= 0 and self.max_epoch > 0
if 'nr_tower' in kwargs or 'tower' in kwargs:
self.set_tower(**kwargs)
if 'nr_tower' in kwargs:
assert 'tower' not in kwargs, "Cannot set both nr_tower and tower in TrainConfig!"
self.nr_tower = kwargs.pop('nr_tower')
elif 'tower' in kwargs:
self.tower = kwargs.pop('tower')
else:
self.tower = [0]
......@@ -64,8 +76,8 @@ class TrainConfig(object):
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
def set_tower(self, nr_tower=None, tower=None):
logger.warn("config.set_tower is deprecated. set config.tower or config.nr_tower directly")
# this is a deprecated function
logger.warn("config.set_tower is deprecated. set config.tower or config.nr_tower directly")
assert nr_tower is None or tower is None, "Cannot set both nr_tower and tower!"
if nr_tower:
tower = list(range(nr_tower))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment