Commit 34f0dd6d authored by Yuxin Wu's avatar Yuxin Wu

docs update and fix bug in ResNet-SE

parent f2b0f1be
...@@ -47,3 +47,5 @@ are likely to have too much variance. You can: ...@@ -47,3 +47,5 @@ are likely to have too much variance. You can:
Besides TensorFlow summaries, Besides TensorFlow summaries,
a callback is free to log any other types of data to the monitor backend, a callback is free to log any other types of data to the monitor backend,
anytime after the training has started. anytime after the training has started.
As long as the type of data is supported, it will be logged by each monitor.
In other words, tensorboard can show not only summaries in the graph, but also your custom data.
...@@ -11,7 +11,7 @@ Models can be [downloaded here](https://goo.gl/6XjK9V). ...@@ -11,7 +11,7 @@ Models can be [downloaded here](https://goo.gl/6XjK9V).
| ResNet18 | 10.55% | 29.73% | | ResNet18 | 10.55% | 29.73% |
| ResNet34 | 8.51% | 26.50% | | ResNet34 | 8.51% | 26.50% |
| ResNet50 | 7.24% | 23.91% | | ResNet50 | 7.24% | 23.91% |
| ResNet50-SE | 6.42% | 22.94% | | ResNet50-SE | TRAINING | TRAINING |
| ResNet101 | 6.26% | 22.53% | | ResNet101 | 6.26% | 22.53% |
To train, just run: To train, just run:
......
...@@ -48,7 +48,7 @@ class Model(ModelDesc): ...@@ -48,7 +48,7 @@ class Model(ModelDesc):
l = Conv2D('conv3', l, ch_out * 4, 1) l = Conv2D('conv3', l, ch_out * 4, 1)
squeeze = GlobalAvgPooling('gap', l) squeeze = GlobalAvgPooling('gap', l)
squeeze = FullyConnected('fc1', squeeze, ch_out // 4, nl=tf.identity) squeeze = FullyConnected('fc1', squeeze, ch_out // 4, nl=tf.nn.relu)
squeeze = FullyConnected('fc2', squeeze, ch_out * 4, nl=tf.nn.sigmoid) squeeze = FullyConnected('fc2', squeeze, ch_out * 4, nl=tf.nn.sigmoid)
l = l * tf.reshape(squeeze, [-1, ch_out * 4, 1, 1]) l = l * tf.reshape(squeeze, [-1, ch_out * 4, 1, 1])
return l + resnet_shortcut(shortcut, ch_out * 4, stride) return l + resnet_shortcut(shortcut, ch_out * 4, stride)
......
...@@ -32,7 +32,7 @@ class NewSessionCreator(tf.train.SessionCreator): ...@@ -32,7 +32,7 @@ class NewSessionCreator(tf.train.SessionCreator):
sess = tf.Session(target=self.target, graph=self.graph, config=self.config) sess = tf.Session(target=self.target, graph=self.graph, config=self.config)
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer()) sess.run(tf.local_variables_initializer())
logger.info("Global variables initialized.") logger.info("Global and local variables initialized.")
return sess return sess
......
...@@ -118,7 +118,7 @@ class Canvas(object): ...@@ -118,7 +118,7 @@ class Canvas(object):
self.nr_col = nr_col self.nr_col = nr_col
if border is None: if border is None:
border = int(0.1 * min(ph, pw)) border = int(0.05 * min(ph, pw))
self.border = border self.border = border
if isinstance(bgcolor, int): if isinstance(bgcolor, int):
...@@ -169,9 +169,9 @@ def stack_patches( ...@@ -169,9 +169,9 @@ def stack_patches(
Args: Args:
patch_list(list[ndarray] or ndarray): NHW or NHWC images in [0,255]. patch_list(list[ndarray] or ndarray): NHW or NHWC images in [0,255].
nr_row(int), nr_col(int): rows and cols of the grid. nr_row(int), nr_col(int): rows and cols of the grid.
``nr_col * nr_row`` must be equal to ``len(patch_list)``. ``nr_col * nr_row`` must be no less than ``len(patch_list)``.
border(int): border length between images. border(int): border length between images.
Defaults to ``0.1 * min(patch_width, patch_height)``. Defaults to ``0.05 * min(patch_width, patch_height)``.
pad (boolean): when `patch_list` is a list, pad all patches to the maximum height and width. pad (boolean): when `patch_list` is a list, pad all patches to the maximum height and width.
This option allows stacking patches of different shapes together. This option allows stacking patches of different shapes together.
bgcolor(int or 3-tuple): background color in [0, 255]. Either an int bgcolor(int or 3-tuple): background color in [0, 255]. Either an int
...@@ -235,7 +235,7 @@ def gen_stack_patches(patch_list, ...@@ -235,7 +235,7 @@ def gen_stack_patches(patch_list,
ph, pw = patch_list.shape[1:3] ph, pw = patch_list.shape[1:3]
if border is None: if border is None:
border = int(0.1 * min(ph, pw)) border = int(0.05 * min(ph, pw))
if nr_row is None: if nr_row is None:
nr_row = int(max_height / (ph + border)) nr_row = int(max_height / (ph + border))
if nr_col is None: if nr_col is None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment