Commit dcbd4696 authored by Yuxin Wu

bnrelu, classificationerror

parent fec3a4a5
@@ -20,8 +20,6 @@ from tensorpack.dataflow import imgaug
 """
 CIFAR10 90% validation accuracy after 70k step.
-91% validation accuracy after 36k step with 3 GPU.
 """
 BATCH_SIZE = 128
@@ -46,15 +44,15 @@ class Model(ModelDesc):
         image = image / 4.0    # just to make range smaller
         l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3)
-        l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=BNReLU(is_training))
+        l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=BNReLU(is_training), use_bias=False)
         l = MaxPooling('pool1', l, 3, stride=2, padding='SAME')
         l = Conv2D('conv2.1', l, out_channel=128, kernel_shape=3)
-        l = Conv2D('conv2.2', l, out_channel=128, kernel_shape=3, nl=BNReLU(is_training))
+        l = Conv2D('conv2.2', l, out_channel=128, kernel_shape=3, nl=BNReLU(is_training), use_bias=False)
         l = MaxPooling('pool2', l, 3, stride=2, padding='SAME')
         l = Conv2D('conv3.1', l, out_channel=128, kernel_shape=3, padding='VALID')
-        l = Conv2D('conv3.2', l, out_channel=128, kernel_shape=3, padding='VALID', nl=BNReLU(is_training))
+        l = Conv2D('conv3.2', l, out_channel=128, kernel_shape=3, padding='VALID', nl=BNReLU(is_training), use_bias=False)
         l = FullyConnected('fc0', l, 1024 + 512,
                            b_init=tf.constant_initializer(0.1))
         l = tf.nn.dropout(l, keep_prob)
@@ -69,7 +67,7 @@ class Model(ModelDesc):
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
-        # compute the number of failed samples, for ValidationError to use at test time
+        # compute the number of failed samples, for ClassificationError to use at test time
         wrong = prediction_incorrect(logits, label)
         nr_wrong = tf.reduce_sum(wrong, name='wrong')
         # monitor training error
@@ -125,7 +123,7 @@ def get_config():
     lr = tf.train.exponential_decay(
         learning_rate=1e-2,
         global_step=get_global_step_var(),
-        decay_steps=dataset_train.size() * 30 if nr_gpu == 1 else 15,
+        decay_steps=dataset_train.size() * 30 if nr_gpu == 1 else 20,
         decay_rate=0.5, staircase=True, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
@@ -135,7 +133,7 @@ def get_config():
         callbacks=Callbacks([
             StatPrinter(),
             PeriodicSaver(),
-            ValidationError(dataset_test, prefix='test'),
+            ClassificationError(dataset_test, prefix='test'),
         ]),
         session_config=sess_config,
         model=Model(),
@@ -155,6 +153,8 @@ if __name__ == '__main__':
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
+    else:
+        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 
     with tf.Graph().as_default():
         config = get_config()
......
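Note on the use_bias=False change above: when a convolution is immediately followed by batch normalization, as with nl=BNReLU(is_training), the conv bias is redundant, because BN subtracts the per-channel batch mean and that subtraction absorbs any constant shift (BN's own beta then provides the learnable offset). A minimal NumPy sketch of that cancellation, illustrative only and not part of the commit:

import numpy as np

def batchnorm(x, eps=1e-5):
    # normalize over the batch axis, per channel; gamma/beta omitted
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.RandomState(0)
conv_out = rng.randn(128, 64)    # stand-in for a conv output: batch x channels
bias = rng.randn(64)             # a per-channel bias

# adding a per-channel constant changes nothing after normalization
print(np.allclose(batchnorm(conv_out), batchnorm(conv_out + bias)))   # True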
@@ -167,7 +167,7 @@ def get_config():
         callbacks=Callbacks([
             StatPrinter(),
             PeriodicSaver(),
-            ValidationError(dataset_test, prefix='test'),
+            ClassificationError(dataset_test, prefix='test'),
             ScheduledHyperParamSetter('learning_rate',
                                       [(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
         ]),
......
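The ScheduledHyperParamSetter entries above pair an epoch with a learning-rate value: [(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)]. A rough sketch of how such an epoch-keyed schedule can be resolved to a value; this only mimics the behavior and is not the callback's actual implementation:

def value_for_epoch(schedule, epoch, default):
    # return the value of the last schedule entry whose epoch is <= the current epoch
    value = default
    for e, v in sorted(schedule):
        if epoch >= e:
            value = v
    return value

schedule = [(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)]
print(value_for_epoch(schedule, 1, 0.1))     # 0.1
print(value_for_epoch(schedule, 100, 0.1))   # 0.01
print(value_for_epoch(schedule, 400, 0.1))   # 0.0002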
@@ -62,7 +62,7 @@ class Model(ModelDesc):
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
-        # compute the number of failed samples, for ValidationError to use at test time
+        # compute the number of failed samples, for ClassificationError to use at test time
         wrong = prediction_incorrect(logits, label)
         nr_wrong = tf.reduce_sum(wrong, name='wrong')
         # monitor training error
@@ -106,7 +106,7 @@ def get_config():
             StatPrinter(),
             PeriodicSaver(),
             ValidationStatPrinter(dataset_test, ['cost:0']),
-            ValidationError(dataset_test, prefix='validation'),
+            ClassificationError(dataset_test, prefix='validation'),
         ]),
         session_config=sess_config,
         model=Model(),
......
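For context on the renamed comment: wrong = prediction_incorrect(logits, label) produces a 0/1 indicator per sample, and nr_wrong sums it to get the number of failed samples in a batch. Roughly equivalent NumPy, as an illustration only (the real helper is a TensorFlow op):

import numpy as np

def prediction_incorrect_np(logits, label):
    # 1.0 for every sample whose highest-scoring class differs from the label
    return (np.argmax(logits, axis=1) != label).astype(np.float32)

logits = np.array([[2.0, 0.5, 0.1],    # predicts class 0
                   [0.1, 0.2, 3.0],    # predicts class 2
                   [1.0, 2.0, 0.3]])   # predicts class 1
label = np.array([0, 1, 1])

wrong = prediction_incorrect_np(logits, label)
nr_wrong = wrong.sum()                 # number of failed samples in this batch
print(wrong, nr_wrong)                 # [0. 1. 0.] 1.0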
@@ -53,7 +53,7 @@ class Model(ModelDesc):
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
-        # compute the number of failed samples, for ValidationError to use at test time
+        # compute the number of failed samples, for ClassificationError to use at test time
         wrong = prediction_incorrect(logits, label)
         nr_wrong = tf.reduce_sum(wrong, name='wrong')
         # monitor training error
@@ -110,7 +110,7 @@ def get_config():
         callbacks=Callbacks([
             StatPrinter(),
             PeriodicSaver(),
-            ValidationError(test, prefix='test'),
+            ClassificationError(test, prefix='test'),
         ]),
         session_config=sess_config,
         model=Model(),
......
@@ -12,7 +12,7 @@ from ..utils.stat import *
 from ..tfutils.summary import *
 from .base import PeriodicCallback, Callback, TestCallbackType
-__all__ = ['ValidationError', 'ValidationCallback', 'ValidationStatPrinter']
+__all__ = ['ClassificationError', 'ValidationCallback', 'ValidationStatPrinter']
 
 class ValidationCallback(PeriodicCallback):
     """
@@ -100,8 +100,7 @@ class ValidationStatPrinter(ValidationCallback):
                 '{}_{}'.format(self.prefix, name), stat), self.global_step)
             self.trainer.stat_holder.add_stat("{}_{}".format(self.prefix, name), stat)
 
-class ValidationError(ValidationCallback):
+class ClassificationError(ValidationCallback):
     """
     Validate the accuracy from a `wrong` variable
@@ -119,7 +118,7 @@ class ValidationError(ValidationCallback):
         :param ds: a batched `DataFlow` instance
         :param wrong_var_name: name of the `wrong` variable
         """
-        super(ValidationError, self).__init__(ds, prefix, period)
+        super(ClassificationError, self).__init__(ds, prefix, period)
         self.wrong_var_name = wrong_var_name
 
     def _find_output_vars(self):
......
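The renamed ClassificationError callback validates accuracy from the `wrong` variable: it reads it over every batch of the validation DataFlow and turns the counts into an error rate. A simplified sketch of that bookkeeping, assuming per-batch (nr_wrong, batch_size) pairs have already been fetched from the graph (names and numbers are illustrative; the real callback runs the tensor through the trainer's session):

class ErrorRatio:
    # accumulate a count over a total, e.g. wrong samples over all validation samples
    def __init__(self):
        self.count, self.total = 0.0, 0

    def feed(self, count, total):
        self.count += count
        self.total += total

    @property
    def ratio(self):
        return self.count / self.total if self.total else 0.0

err = ErrorRatio()
for nr_wrong, batch_size in [(13, 128), (9, 128), (11, 128)]:   # made-up batches
    err.feed(nr_wrong, batch_size)
print('validation_error: %.4f' % err.ratio)    # 0.0859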