Commit dcbd4696 authored by Yuxin Wu

bnrelu, classificationerror

parent fec3a4a5
@@ -20,8 +20,6 @@ from tensorpack.dataflow import imgaug
"""
CIFAR10 90% validation accuracy after 70k step.
91% validation accuracy after 36k step with 3 GPU.
"""
BATCH_SIZE = 128
@@ -46,15 +44,15 @@ class Model(ModelDesc):
image = image / 4.0 # just to make range smaller
l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3)
- l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=BNReLU(is_training))
+ l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=BNReLU(is_training), use_bias=False)
l = MaxPooling('pool1', l, 3, stride=2, padding='SAME')
l = Conv2D('conv2.1', l, out_channel=128, kernel_shape=3)
- l = Conv2D('conv2.2', l, out_channel=128, kernel_shape=3, nl=BNReLU(is_training))
+ l = Conv2D('conv2.2', l, out_channel=128, kernel_shape=3, nl=BNReLU(is_training), use_bias=False)
l = MaxPooling('pool2', l, 3, stride=2, padding='SAME')
l = Conv2D('conv3.1', l, out_channel=128, kernel_shape=3, padding='VALID')
- l = Conv2D('conv3.2', l, out_channel=128, kernel_shape=3, padding='VALID', nl=BNReLU(is_training))
+ l = Conv2D('conv3.2', l, out_channel=128, kernel_shape=3, padding='VALID', nl=BNReLU(is_training), use_bias=False)
l = FullyConnected('fc0', l, 1024 + 512,
b_init=tf.constant_initializer(0.1))
l = tf.nn.dropout(l, keep_prob)
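BNReLU(is_training) returns a nonlinearity that applies batch normalization followed by ReLU; since batch normalization already adds a learned per-channel offset (beta), a bias in the convolution feeding into it is redundant, which is what the added use_bias=False reflects. A rough standalone sketch of such a wrapper, using a later TF 1.x API purely for illustration and not tensorpack's actual BNReLU:

import tensorflow as tf

def BNReLU(is_training):
    # Returns a callable usable as `nl=...`: batch-norm, then ReLU.
    # The preceding conv can drop its bias because BN's beta plays the same role.
    def activate(x, name=None):
        x = tf.layers.batch_normalization(x, training=is_training)
        return tf.nn.relu(x, name=name)
    return activate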
@@ -69,7 +67,7 @@ class Model(ModelDesc):
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
- # compute the number of failed samples, for ValidationError to use at test time
+ # compute the number of failed samples, for ClassificationError to use at test time
wrong = prediction_incorrect(logits, label)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
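prediction_incorrect itself is outside this diff; conceptually it emits 1.0 for every sample whose top prediction disagrees with its label, so `wrong` summed into nr_wrong gives the per-batch count that the renamed ClassificationError callback aggregates at validation time. A minimal sketch of such a helper, assuming logits of shape [batch, num_classes] and integer labels of shape [batch] (not necessarily tensorpack's exact definition):

def prediction_incorrect(logits, label):
    # 1.0 where the top-1 prediction differs from the ground-truth label, else 0.0.
    wrong = tf.logical_not(tf.nn.in_top_k(logits, label, 1))
    return tf.cast(wrong, tf.float32)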
@@ -125,7 +123,7 @@ def get_config():
lr = tf.train.exponential_decay(
learning_rate=1e-2,
global_step=get_global_step_var(),
- decay_steps=dataset_train.size() * 30 if nr_gpu == 1 else 15,
+ decay_steps=dataset_train.size() * 30 if nr_gpu == 1 else 20,
decay_rate=0.5, staircase=True, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
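With staircase=True, tf.train.exponential_decay drops the learning rate in discrete steps: lr = 1e-2 * 0.5 ** floor(global_step / decay_steps). Note that the conditional expression above binds more loosely than `*`, so decay_steps evaluates to dataset_train.size() * 30 on a single GPU and to the literal 20 otherwise. Written out as plain Python for reference:

def lr_at(step, decay_steps, base_lr=1e-2, decay_rate=0.5):
    # What exponential_decay(staircase=True) evaluates to at a given global step.
    return base_lr * decay_rate ** (step // decay_steps)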
@@ -135,7 +133,7 @@ def get_config():
callbacks=Callbacks([
StatPrinter(),
PeriodicSaver(),
- ValidationError(dataset_test, prefix='test'),
+ ClassificationError(dataset_test, prefix='test'),
]),
session_config=sess_config,
model=Model(),
@@ -155,6 +153,8 @@ if __name__ == '__main__':
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
else:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
with tf.Graph().as_default():
config = get_config()
@@ -167,7 +167,7 @@ def get_config():
callbacks=Callbacks([
StatPrinter(),
PeriodicSaver(),
- ValidationError(dataset_test, prefix='test'),
+ ClassificationError(dataset_test, prefix='test'),
ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
]),
@@ -62,7 +62,7 @@ class Model(ModelDesc):
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
- # compute the number of failed samples, for ValidationError to use at test time
+ # compute the number of failed samples, for ClassificationError to use at test time
wrong = prediction_incorrect(logits, label)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
@@ -106,7 +106,7 @@ def get_config():
StatPrinter(),
PeriodicSaver(),
ValidationStatPrinter(dataset_test, ['cost:0']),
- ValidationError(dataset_test, prefix='validation'),
+ ClassificationError(dataset_test, prefix='validation'),
]),
session_config=sess_config,
model=Model(),
@@ -53,7 +53,7 @@ class Model(ModelDesc):
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
- # compute the number of failed samples, for ValidationError to use at test time
+ # compute the number of failed samples, for ClassificationError to use at test time
wrong = prediction_incorrect(logits, label)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
@@ -110,7 +110,7 @@ def get_config():
callbacks=Callbacks([
StatPrinter(),
PeriodicSaver(),
- ValidationError(test, prefix='test'),
+ ClassificationError(test, prefix='test'),
]),
session_config=sess_config,
model=Model(),
@@ -12,7 +12,7 @@ from ..utils.stat import *
from ..tfutils.summary import *
from .base import PeriodicCallback, Callback, TestCallbackType
- __all__ = ['ValidationError', 'ValidationCallback', 'ValidationStatPrinter']
+ __all__ = ['ClassificationError', 'ValidationCallback', 'ValidationStatPrinter']
class ValidationCallback(PeriodicCallback):
"""
@@ -100,8 +100,7 @@ class ValidationStatPrinter(ValidationCallback):
'{}_{}'.format(self.prefix, name), stat), self.global_step)
self.trainer.stat_holder.add_stat("{}_{}".format(self.prefix, name), stat)
- class ValidationError(ValidationCallback):
+ class ClassificationError(ValidationCallback):
"""
Validate the accuracy from a `wrong` variable
@@ -119,7 +118,7 @@ class ValidationError(ValidationCallback):
:param ds: a batched `DataFlow` instance
:param wrong_var_name: name of the `wrong` variable
"""
- super(ValidationError, self).__init__(ds, prefix, period)
+ super(ClassificationError, self).__init__(ds, prefix, period)
self.wrong_var_name = wrong_var_name
def _find_output_vars(self):
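Conceptually, the renamed ClassificationError callback runs the validation DataFlow through the graph, sums the counts fetched via wrong_var_name, and reports the ratio of misclassified samples. A minimal standalone sketch of that aggregation, independent of tensorpack's callback machinery (run_batch is a hypothetical function returning the wrong count and batch size for one datapoint):

def classification_error(run_batch, dataflow):
    # Error rate = total wrong predictions / total validation samples.
    total_wrong, total = 0, 0
    for dp in dataflow.get_data():
        nr_wrong, batch_size = run_batch(dp)
        total_wrong += nr_wrong
        total += batch_size
    return float(total_wrong) / total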