Commit 34f0dd6d authored by Yuxin Wu's avatar Yuxin Wu

docs update and fix bug in ResNet-SE

parent f2b0f1be
......@@ -47,3 +47,5 @@ are likely to have too much variance. You can:
Besides TensorFlow summaries,
a callback is free to log any other types of data to the monitor backend,
anytime after the training has started.
As long as the type of data is supported, it will be logged by each monitor.
In other words, tensorboard can show not only summaries in the graph, but also your custom data.
......@@ -11,7 +11,7 @@ Models can be [downloaded here](https://goo.gl/6XjK9V).
| ResNet18 | 10.55% | 29.73% |
| ResNet34 | 8.51% | 26.50% |
| ResNet50 | 7.24% | 23.91% |
| ResNet50-SE | 6.42% | 22.94% |
| ResNet50-SE | TRAINING | TRAINING |
| ResNet101 | 6.26% | 22.53% |
To train, just run:
......
......@@ -48,7 +48,7 @@ class Model(ModelDesc):
l = Conv2D('conv3', l, ch_out * 4, 1)
squeeze = GlobalAvgPooling('gap', l)
squeeze = FullyConnected('fc1', squeeze, ch_out // 4, nl=tf.identity)
squeeze = FullyConnected('fc1', squeeze, ch_out // 4, nl=tf.nn.relu)
squeeze = FullyConnected('fc2', squeeze, ch_out * 4, nl=tf.nn.sigmoid)
l = l * tf.reshape(squeeze, [-1, ch_out * 4, 1, 1])
return l + resnet_shortcut(shortcut, ch_out * 4, stride)
......
......@@ -32,7 +32,7 @@ class NewSessionCreator(tf.train.SessionCreator):
sess = tf.Session(target=self.target, graph=self.graph, config=self.config)
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
logger.info("Global variables initialized.")
logger.info("Global and local variables initialized.")
return sess
......
......@@ -118,7 +118,7 @@ class Canvas(object):
self.nr_col = nr_col
if border is None:
border = int(0.1 * min(ph, pw))
border = int(0.05 * min(ph, pw))
self.border = border
if isinstance(bgcolor, int):
......@@ -169,9 +169,9 @@ def stack_patches(
Args:
patch_list(list[ndarray] or ndarray): NHW or NHWC images in [0,255].
nr_row(int), nr_col(int): rows and cols of the grid.
``nr_col * nr_row`` must be equal to ``len(patch_list)``.
``nr_col * nr_row`` must be no less than ``len(patch_list)``.
border(int): border length between images.
Defaults to ``0.1 * min(patch_width, patch_height)``.
Defaults to ``0.05 * min(patch_width, patch_height)``.
pad (boolean): when `patch_list` is a list, pad all patches to the maximum height and width.
This option allows stacking patches of different shapes together.
bgcolor(int or 3-tuple): background color in [0, 255]. Either an int
......@@ -235,7 +235,7 @@ def gen_stack_patches(patch_list,
ph, pw = patch_list.shape[1:3]
if border is None:
border = int(0.1 * min(ph, pw))
border = int(0.05 * min(ph, pw))
if nr_row is None:
nr_row = int(max_height / (ph + border))
if nr_col is None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment