Commit 9a4e6d9d authored by Yuxin Wu's avatar Yuxin Wu

use sparse softmax

parent d646972d
...@@ -19,7 +19,8 @@ from tensorpack.dataflow import * ...@@ -19,7 +19,8 @@ from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug from tensorpack.dataflow import imgaug
""" """
CIFAR10 90% validation accuracy after 40k step. A small cifar10 convnet model.
90% validation accuracy after 40k step.
""" """
BATCH_SIZE = 128 BATCH_SIZE = 128
...@@ -62,8 +63,7 @@ class Model(ModelDesc): ...@@ -62,8 +63,7 @@ class Model(ModelDesc):
# fc will have activation summary by default. disable for the output layer # fc will have activation summary by default. disable for the output layer
logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity) logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity)
y = one_hot(label, 10) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
......
...@@ -105,8 +105,7 @@ class Model(ModelDesc): ...@@ -105,8 +105,7 @@ class Model(ModelDesc):
logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity) logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity)
prob = tf.nn.softmax(logits, name='output') prob = tf.nn.softmax(logits, name='output')
y = one_hot(label, 10) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
......
...@@ -22,7 +22,7 @@ from IPython import embed; embed() ...@@ -22,7 +22,7 @@ from IPython import embed; embed()
""" """
MNIST ConvNet example. MNIST ConvNet example.
about 0.6% validation error after 50 epochs. about 0.6% validation error after 30 epochs.
""" """
BATCH_SIZE = 128 BATCH_SIZE = 128
...@@ -58,8 +58,7 @@ class Model(ModelDesc): ...@@ -58,8 +58,7 @@ class Model(ModelDesc):
logits = FullyConnected('fc1', l, out_dim=10, nl=tf.identity) logits = FullyConnected('fc1', l, out_dim=10, nl=tf.identity)
prob = tf.nn.softmax(logits, name='prob') prob = tf.nn.softmax(logits, name='prob')
y = one_hot(label, 10) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
...@@ -97,7 +96,7 @@ def get_config(): ...@@ -97,7 +96,7 @@ def get_config():
learning_rate=1e-3, learning_rate=1e-3,
global_step=get_global_step_var(), global_step=get_global_step_var(),
decay_steps=dataset_train.size() * 10, decay_steps=dataset_train.size() * 10,
decay_rate=0.5, staircase=True, name='learning_rate') decay_rate=0.3, staircase=True, name='learning_rate')
tf.scalar_summary('learning_rate', lr) tf.scalar_summary('learning_rate', lr)
return TrainConfig( return TrainConfig(
......
...@@ -48,8 +48,7 @@ class Model(ModelDesc): ...@@ -48,8 +48,7 @@ class Model(ModelDesc):
logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity) logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity)
prob = tf.nn.softmax(logits, name='output') prob = tf.nn.softmax(logits, name='output')
y = one_hot(label, 10) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment