Commit 18b19d6d authored by Yuxin Wu

DoReFa uses ImageNetModel

parent 41759741
@@ -96,7 +96,7 @@ class AtariPlayer(gym.Env):
         self.action_space = spaces.Discrete(len(self.actions))
         self.observation_space = spaces.Box(
-            low=0, high=255, shape=(self.height, self.width))
+            low=0, high=255, shape=(self.height, self.width), dtype=np.uint8)
         self._restart_episode()

     def get_action_meanings(self):
...
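For context: recent gym releases expect a Box space to carry an explicit dtype and warn when it has to be inferred, which is what the added dtype=np.uint8 addresses. A minimal sketch of the behavior, assuming a current gym install (the fixed 84x84 shape is illustrative; the real code uses self.height and self.width):

import numpy as np
from gym import spaces

# uint8 matches the 0-255 pixel range of an Atari frame; without dtype=,
# gym infers a type from low/high and emits a warning.
obs_space = spaces.Box(low=0, high=255, shape=(84, 84), dtype=np.uint8)
frame = obs_space.sample()
assert frame.dtype == np.uint8 and obs_space.contains(frame)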
@@ -12,13 +12,12 @@ import sys
 from tensorpack import *
-from tensorpack.tfutils.symbolic_functions import prediction_incorrect
-from tensorpack.tfutils.summary import add_moving_summary, add_param_summary
+from tensorpack.tfutils.summary import add_param_summary
 from tensorpack.tfutils.varreplace import remap_variables
 from tensorpack.dataflow import dataset
 from tensorpack.utils.gpu import get_nr_gpu

-from imagenet_utils import get_imagenet_dataflow, fbresnet_augmentor
+from imagenet_utils import get_imagenet_dataflow, fbresnet_augmentor, ImageNetModel
 from dorefa import get_dorefa, ternarize

 """
@@ -59,15 +58,11 @@ TOTAL_BATCH_SIZE = 256
 BATCH_SIZE = None


-class Model(ModelDesc):
-    def inputs(self):
-        return [tf.placeholder(tf.float32, [None, 224, 224, 3], 'input'),
-                tf.placeholder(tf.int32, [None], 'label')]
-
-    def build_graph(self, image, label):
-        image = image / 255.0
-        image = tf.transpose(image, [0, 3, 1, 2])
+class Model(ImageNetModel):
+    weight_decay = 5e-6
+    weight_decay_pattern = 'fc.*/W'
+
+    def get_logits(self, image):
         if BITW == 't':
             fw, fa, fg = get_dorefa(32, 32, 32)
             fw = ternarize
@@ -97,7 +92,7 @@ class Model(ModelDesc):
                 argscope(BatchNorm, momentum=0.9, epsilon=1e-4), \
                 argscope(Conv2D, use_bias=False):
             logits = (LinearWrap(image)
-                      .Conv2D('conv0', 96, 12, strides=4, padding='VALID')
+                      .Conv2D('conv0', 96, 12, strides=4, padding='VALID', use_bias=True)
                       .apply(activate)
                       .Conv2D('conv1', 256, 5, padding='SAME', split=2)
                       .apply(fg)
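Note the interplay with argscope here: conv0 sits inside argscope(Conv2D, use_bias=False), so its bias has to be switched back on explicitly; in tensorpack, a keyword given at the call site takes precedence over the argscope default. A small illustrative sketch of that rule (layer names and shapes are made up):

import tensorflow as tf
from tensorpack import argscope
from tensorpack.models import Conv2D

def stem(image):
    with argscope(Conv2D, use_bias=False):
        # explicit kwarg at the call site wins over the argscope default
        l = Conv2D('a', image, 96, 12, strides=4, use_bias=True)
        # no explicit kwarg: inherits use_bias=False from the argscope
        l = Conv2D('b', l, 256, 5)
        return l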
@@ -132,24 +127,8 @@ class Model(ModelDesc):
                       .BatchNorm('bnfc1')
                       .apply(nonlin)
                       .FullyConnected('fct', 1000, use_bias=True)())
-        tf.nn.softmax(logits, name='output')
-
-        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
-        cost = tf.reduce_mean(cost, name='cross_entropy_loss')
-
-        wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
-        add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1'))
-        wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
-        add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5'))
-
-        # weight decay on all W of fc layers
-        wd_cost = regularize_cost('fc.*/W', l2_regularizer(5e-6), name='regularize_cost')
-        add_param_summary(('.*/W', ['histogram', 'rms']))
-
-        total_cost = tf.add_n([cost, wd_cost], name='cost')
-        add_moving_summary(cost, wd_cost, total_cost)
-        return total_cost
+        return logits

     def optimizer(self):
         lr = tf.get_variable('learning_rate', initializer=2e-4, trainable=False)
...
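The block deleted above is what the shared ImageNetModel base class now supplies: input placeholders, image preprocessing, the softmax cross-entropy loss, error summaries, and weight decay driven by the weight_decay / weight_decay_pattern class attributes, leaving get_logits as the only hook a subclass writes. A rough, simplified sketch of that contract, reconstructed from the deleted lines (this is not the actual imagenet_utils source):

import tensorflow as tf
from tensorpack import ModelDesc, regularize_cost

class ImageNetModelSketch(ModelDesc):
    weight_decay = 1e-4            # the DoReFa subclass overrides this to 5e-6
    weight_decay_pattern = '.*/W'  # ... and this to 'fc.*/W'

    def inputs(self):
        return [tf.placeholder(tf.float32, [None, 224, 224, 3], 'input'),
                tf.placeholder(tf.int32, [None], 'label')]

    def get_logits(self, image):
        raise NotImplementedError   # the only method a subclass must provide

    def build_graph(self, image, label):
        image = image / 255.0                      # preprocessing formerly inlined
        image = tf.transpose(image, [0, 3, 1, 2])  # NHWC -> NCHW
        logits = self.get_logits(image)
        cost = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label),
            name='cross_entropy_loss')
        wd_cost = regularize_cost(self.weight_decay_pattern,
                                  tf.contrib.layers.l2_regularizer(self.weight_decay),
                                  name='regularize_cost')
        return tf.add_n([cost, wd_cost], name='cost')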
@@ -30,7 +30,7 @@ This Inception-BN script reaches 27% single-crop error after 300k steps with 6 G
 This VGG16 script, when trained with 32x8 batch size, reaches the following
 error rate after 100 epochs (30h with 8 P100s). This reproduces the VGG
-experiements in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).
+experiments in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).

 | No Normalization | Batch Normalization | Group Normalization |
 |:---------------------------------|---------------------|--------------------:|
...