Commit a041f5a9 authored by Yuxin Wu

example docs and improvements

parent 64bda846
@@ -41,7 +41,7 @@ Multi-GPU training is off-the-shelf by simply switching the trainer.
 ## Dependencies:
 + Python 2 or 3
-+ TensorFlow >= 0.8
++ TensorFlow >= 0.10
 + Python bindings for OpenCV
 + other requirements:
 ```
...
@@ -27,6 +27,7 @@ To start training:
 ```bash
 ./hed.py --load vgg16.npy
 ```
+It takes about 100k steps (~10 hours on a TitanX) to reach reasonable performance.
 To run inference (produce a heatmap at each level as out*.png):
 ```bash
...
@@ -21,8 +21,8 @@ See "Rethinking the Inception Architecture for Computer Vision", arxiv:1512.00567
 This config follows the official inceptionv3 setup (https://github.com/tensorflow/models/tree/master/inception/inception)
 with far fewer lines of code.
-It reaches 74.5% single-crop validation accuracy, slightly better than the official code,
-and has the same running speed as well.
+It reaches 74% single-crop validation accuracy,
+and has the same running speed as the official code.
 The hyperparameters here are for 8 GPUs, so the effective batch size is 8*64 = 512.
 With 8 TitanX it runs about 0.45 it/s.
 """
...
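For scale: 0.45 it/s at an effective batch of 512 works out to roughly 0.45*512 ≈ 230 images/s across the 8 GPUs.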
@@ -51,11 +51,9 @@ class Model(ModelDesc):
                 .MaxPooling('pool2', 3, stride=2, padding='SAME') \
                 .Conv2D('conv3.1', out_channel=128, padding='VALID') \
                 .Conv2D('conv3.2', out_channel=128, padding='VALID') \
-                .FullyConnected('fc0', 1024 + 512, nl=tf.nn.relu,
-                                b_init=tf.constant_initializer(0.1)) \
+                .FullyConnected('fc0', 1024 + 512, nl=tf.nn.relu) \
                 .tf.nn.dropout(keep_prob) \
-                .FullyConnected('fc1', 512, nl=tf.nn.relu,
-                                b_init=tf.constant_initializer(0.1)) \
+                .FullyConnected('fc1', 512, nl=tf.nn.relu) \
                 .FullyConnected('linear', out_dim=self.cifar_classnum, nl=tf.identity)()
         cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
...
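The chained style in this hunk comes from tensorpack's LinearWrap: attribute access such as .Conv2D(...) dispatches to a registered layer, while the .tf.nn.dropout(keep_prob) step shows that raw TensorFlow functions can be chained the same way, each call threading the current tensor through. The edit itself just drops the custom b_init=tf.constant_initializer(0.1), falling back to the layer's default bias initializer. Below is a toy sketch of the chaining idea only; ToyWrap and _TFForwarder are hypothetical names, not tensorpack's implementation:

```python
import tensorflow as tf

class _TFForwarder(object):
    """Resolves attribute chains like .tf.nn.dropout down to a callable."""
    def __init__(self, wrap, obj):
        self._wrap, self._obj = wrap, obj

    def __getattr__(self, name):
        return _TFForwarder(self._wrap, getattr(self._obj, name))

    def __call__(self, *args, **kwargs):
        # apply the raw TF function to the wrapped tensor, rewrap the result
        return ToyWrap(self._obj(self._wrap._t, *args, **kwargs))

class ToyWrap(object):
    """Threads a tensor through chained calls; a final () unwraps it."""
    def __init__(self, tensor):
        self._t = tensor

    @property
    def tf(self):
        return _TFForwarder(self, tf)   # enables .tf.nn.dropout(...) chaining

    def FullyConnected(self, name, out_dim):
        # stand-in for one registered layer: a plain dense layer
        in_dim = self._t.get_shape().as_list()[-1]
        W = tf.get_variable(name + '/W', [in_dim, out_dim])
        b = tf.get_variable(name + '/b', [out_dim])
        return ToyWrap(tf.nn.xw_plus_b(self._t, W, b))

    def __call__(self):
        return self._t

# usage: y = ToyWrap(x).FullyConnected('fc0', 512).tf.nn.relu()()
```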
@@ -14,9 +14,10 @@ from tensorpack.tfutils.summary import *
 """
 A very small SVHN convnet model (only 0.8m parameters).
-About 3.0% validation error after 70 epochs. 2.5% after 130 epochs.
-Each epoch is set to 4721 iterations. The speed is about 44 it/s on a Tesla M40.
+About 2.3% validation error after 70 epochs. 2.15% after 150 epochs.
+Each epoch iterates over the whole training set (4721 iterations).
+Speed is about 43 it/s on TitanX.
 """
 class Model(ModelDesc):
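At ~43 it/s, one 4721-iteration epoch takes about 110 seconds, so the 70 epochs quoted above finish in roughly two hours on a TitanX.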
@@ -29,7 +30,7 @@ class Model(ModelDesc):
         image = image / 128.0 - 1
-        with argscope(Conv2D, nl=tf.nn.relu):
+        with argscope(Conv2D, nl=BNReLU, use_bias=False):
             logits = (LinearWrap(image)
                       .Conv2D('conv1', 24, 5, padding='VALID')
                       .MaxPooling('pool1', 2, padding='SAME')
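The new argscope line gives every Conv2D in the block nl=BNReLU and use_bias=False by default; dropping the bias is sound because batch normalization's learned offset (beta) absorbs it. A minimal sketch of what an argscope-style mechanism does, with toy_argscope as a hypothetical stand-in rather than tensorpack's code:

```python
import contextlib

_SCOPE_STACK = []  # each entry maps layer name -> default kwargs

@contextlib.contextmanager
def toy_argscope(layer, **defaults):
    # push defaults for this layer, merged over any enclosing scope
    top = dict(_SCOPE_STACK[-1]) if _SCOPE_STACK else {}
    merged = dict(top.get(layer.__name__, {}))
    merged.update(defaults)
    top[layer.__name__] = merged
    _SCOPE_STACK.append(top)
    try:
        yield
    finally:
        _SCOPE_STACK.pop()

def apply_scoped_defaults(layer, kwargs):
    # fill in scope defaults for kwargs the caller did not pass explicitly;
    # a layer implementation would call this before building its ops
    if _SCOPE_STACK:
        for k, v in _SCOPE_STACK[-1].get(layer.__name__, {}).items():
            kwargs.setdefault(k, v)
    return kwargs
```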
@@ -62,7 +63,7 @@ def get_data():
     d1 = dataset.SVHNDigit('train')
     d2 = dataset.SVHNDigit('extra')
     data_train = RandomMixData([d1, d2])
-    data_test = dataset.SVHNDigit('test')
+    data_test = dataset.SVHNDigit('test', shuffle=False)
     augmentors = [
         imgaug.Resize((40, 40)),
...
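shuffle=False keeps the test split in a deterministic order, making evaluation reproducible; the training stream, randomly mixed from the 'train' and 'extra' splits by RandomMixData, needs no such guarantee. For context, a hedged sketch of how a get_data() like this typically finishes in tensorpack examples of this era; the wiring below and the batch size of 128 are assumptions, not part of this commit:

```python
from tensorpack.dataflow import (dataset, imgaug, RandomMixData,
                                 AugmentImageComponent, BatchData)

def get_data(train_or_test):
    isTrain = train_or_test == 'train'
    if isTrain:
        d1 = dataset.SVHNDigit('train')
        d2 = dataset.SVHNDigit('extra')
        ds = RandomMixData([d1, d2])
    else:
        ds = dataset.SVHNDigit('test', shuffle=False)  # deterministic eval order
    augmentors = [imgaug.Resize((40, 40))]
    ds = AugmentImageComponent(ds, augmentors)      # augment the image component
    ds = BatchData(ds, 128, remainder=not isTrain)  # keep the last partial batch at eval
    return ds
```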
@@ -24,7 +24,7 @@ with tf.Graph().as_default() as G:
     if args.config:
         MODEL = imp.load_source('config_script', args.config).Model
        M = MODEL()
-        M.build_graph(M.get_input_vars(), is_training=False)
+        M.build_graph(M.get_input_vars())
     else:
         M = ModelFromMetaGraph(args.meta)
...
@@ -60,8 +60,7 @@ def Conv2D(x, out_channel, kernel_shape,
                    for i, k in zip(inputs, kernels)]
         conv = tf.concat(3, outputs)
     if nl is None:
-        logger.warn("[DEPRECATED] Default nonlinearity for Conv2D and FullyConnected will be deprecated.")
-        logger.warn("[DEPRECATED] Please use argscope instead.")
+        logger.warn("[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
         nl = tf.nn.relu
     return nl(tf.nn.bias_add(conv, b) if use_bias else conv, name='output')
@@ -40,7 +40,6 @@ def FullyConnected(x, out_dim,
     b = tf.get_variable('b', [out_dim], initializer=b_init)
     prod = tf.nn.xw_plus_b(x, W, b) if use_bias else tf.matmul(x, W)
     if nl is None:
-        logger.warn("[DEPRECATED] Default nonlinearity for Conv2D and FullyConnected will be deprecated.")
-        logger.warn("[DEPRECATED] Please use argscope instead.")
+        logger.warn("[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
         nl = tf.nn.relu
     return nl(prod, name='output')
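Both layers now emit one consolidated warning instead of two when nl is left at its default. Callers avoid the deprecation path by choosing the nonlinearity explicitly, per call or for a whole block via argscope, as the SVHN model above does. A usage sketch; the exact call signatures are assumptions patterned on the code this commit shows:

```python
import tensorflow as tf
from tensorpack import *  # assumed to expose Conv2D, argscope, BNReLU

image = tf.placeholder(tf.float32, [None, 40, 40, 3])

# explicit per-call nonlinearity: no deprecation warning
l = Conv2D('conv1', image, out_channel=32, kernel_shape=3, nl=tf.nn.relu)

# or one block-wide default, mirroring the SVHN hunk above
with argscope(Conv2D, nl=BNReLU, use_bias=False):
    l = Conv2D('conv2', l, out_channel=32, kernel_shape=3)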
@@ -50,13 +50,25 @@ class TrainConfig(object):
         assert_type(self.session_config, tf.ConfigProto)
         self.session_init = kwargs.pop('session_init', JustCurrentSession())
         assert_type(self.session_init, SessionInit)
-        self.step_per_epoch = int(kwargs.pop('step_per_epoch'))
+        self.step_per_epoch = kwargs.pop('step_per_epoch', None)
+        if self.step_per_epoch is None:
+            try:
+                self.step_per_epoch = self.dataset.size()
+            except NotImplementedError:
+                logger.exception("You must set `step_per_epoch` if dataset.size() is not implemented.")
+        else:
+            self.step_per_epoch = int(self.step_per_epoch)
         self.starting_epoch = int(kwargs.pop('starting_epoch', 1))
         self.max_epoch = int(kwargs.pop('max_epoch', 99999))
         assert self.step_per_epoch >= 0 and self.max_epoch > 0
-        if 'nr_tower' in kwargs or 'tower' in kwargs:
-            self.set_tower(**kwargs)
+        if 'nr_tower' in kwargs:
+            assert 'tower' not in kwargs, "Cannot set both nr_tower and tower in TrainConfig!"
+            self.nr_tower = kwargs.pop('nr_tower')
+        elif 'tower' in kwargs:
+            self.tower = kwargs.pop('tower')
         else:
             self.tower = [0]
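The effect is that step_per_epoch becomes optional: when omitted, TrainConfig asks the DataFlow for its size, and only datasets without a size() implementation still require it. A hedged usage sketch; apart from dataset, step_per_epoch, and max_epoch, the constructor fields shown are assumptions about the rest of the class, not part of this hunk:

```python
import tensorflow as tf

# get_data() and Model are the example functions from the hunks above
config = TrainConfig(
    dataset=get_data('train'),               # step_per_epoch defaults to its size()
    optimizer=tf.train.AdamOptimizer(1e-3),  # assumed field, not in this hunk
    model=Model(),                           # assumed field, not in this hunk
    max_epoch=100,
)
```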
@@ -64,8 +76,8 @@ class TrainConfig(object):
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

     def set_tower(self, nr_tower=None, tower=None):
+        logger.warn("config.set_tower is deprecated. set config.tower or config.nr_tower directly")
         # this is a deprecated function
-        logger.warn("config.set_tower is deprecated. set config.tower or config.nr_tower directly")
         assert nr_tower is None or tower is None, "Cannot set both nr_tower and tower!"
         if nr_tower:
             tower = list(range(nr_tower))
...
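The migration the warning asks for, assuming nr_tower and tower are plain settable attributes as the message itself suggests:

```python
config.set_tower(nr_tower=2)   # deprecated: logs the warning above
config.nr_tower = 2            # preferred: two training towers
config.tower = [0, 1]          # equivalent: list the tower indices explicitly
```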