Commit 18b19d6d authored by Yuxin Wu

DoReFa uses ImageNetModel

parent 41759741
@@ -96,7 +96,7 @@ class AtariPlayer(gym.Env):
         self.action_space = spaces.Discrete(len(self.actions))
         self.observation_space = spaces.Box(
-            low=0, high=255, shape=(self.height, self.width))
+            low=0, high=255, shape=(self.height, self.width), dtype=np.uint8)
         self._restart_episode()

     def get_action_meanings(self):
...
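For context on the `dtype=np.uint8` addition above: recent gym releases expect `spaces.Box` to be given an explicit dtype and otherwise autodetect one (with a warning), which is a poor fit for 8-bit Atari frames. A minimal sketch, assuming a recent gym and a hypothetical frame size:

```python
import numpy as np
from gym import spaces

# Hypothetical Atari frame size, for illustration only.
height, width = 210, 160

# Without dtype=..., newer gym versions autodetect a dtype (and warn);
# passing np.uint8 declares the 8-bit pixel observations explicitly.
observation_space = spaces.Box(low=0, high=255, shape=(height, width), dtype=np.uint8)
print(observation_space.dtype)  # uint8
```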
@@ -12,13 +12,12 @@ import sys

 from tensorpack import *
-from tensorpack.tfutils.symbolic_functions import prediction_incorrect
-from tensorpack.tfutils.summary import add_moving_summary, add_param_summary
+from tensorpack.tfutils.summary import add_param_summary
 from tensorpack.tfutils.varreplace import remap_variables
 from tensorpack.dataflow import dataset
 from tensorpack.utils.gpu import get_nr_gpu

-from imagenet_utils import get_imagenet_dataflow, fbresnet_augmentor
+from imagenet_utils import get_imagenet_dataflow, fbresnet_augmentor, ImageNetModel
 from dorefa import get_dorefa, ternarize

 """
@@ -59,15 +58,11 @@ TOTAL_BATCH_SIZE = 256
 BATCH_SIZE = None


-class Model(ModelDesc):
-    def inputs(self):
-        return [tf.placeholder(tf.float32, [None, 224, 224, 3], 'input'),
-                tf.placeholder(tf.int32, [None], 'label')]
-
-    def build_graph(self, image, label):
-        image = image / 255.0
-        image = tf.transpose(image, [0, 3, 1, 2])
-
+class Model(ImageNetModel):
+    weight_decay = 5e-6
+    weight_decay_pattern = 'fc.*/W'
+
+    def get_logits(self, image):
         if BITW == 't':
             fw, fa, fg = get_dorefa(32, 32, 32)
             fw = ternarize
@@ -97,7 +92,7 @@ class Model(ModelDesc):
                 argscope(BatchNorm, momentum=0.9, epsilon=1e-4), \
                 argscope(Conv2D, use_bias=False):
             logits = (LinearWrap(image)
-                      .Conv2D('conv0', 96, 12, strides=4, padding='VALID')
+                      .Conv2D('conv0', 96, 12, strides=4, padding='VALID', use_bias=True)
                       .apply(activate)
                       .Conv2D('conv1', 256, 5, padding='SAME', split=2)
                       .apply(fg)
@@ -132,24 +127,8 @@ class Model(ModelDesc):
                       .BatchNorm('bnfc1')
                       .apply(nonlin)
                       .FullyConnected('fct', 1000, use_bias=True)())
-        tf.nn.softmax(logits, name='output')
-
-        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
-        cost = tf.reduce_mean(cost, name='cross_entropy_loss')
-
-        wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
-        add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1'))
-        wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
-        add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5'))
-
-        # weight decay on all W of fc layers
-        wd_cost = regularize_cost('fc.*/W', l2_regularizer(5e-6), name='regularize_cost')
         add_param_summary(('.*/W', ['histogram', 'rms']))
-        total_cost = tf.add_n([cost, wd_cost], name='cost')
-        add_moving_summary(cost, wd_cost, total_cost)
-        return total_cost
+        return logits

     def optimizer(self):
         lr = tf.get_variable('learning_rate', initializer=2e-4, trainable=False)
...
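The per-model softmax/cost/summary/weight-decay code disappears above because `ImageNetModel` (imported from `imagenet_utils`) now owns the input placeholders, preprocessing, loss, and regularization, driven by the `weight_decay` and `weight_decay_pattern` class attributes; the subclass only supplies `get_logits()`. A minimal sketch of that contract, not tensorpack's actual `ImageNetModel` (the class name, the simplified preprocessing, and the use of `tf.contrib.layers.l2_regularizer` are assumptions for illustration):

```python
import tensorflow as tf
from tensorpack import *  # provides ModelDesc, regularize_cost, etc., as in the example script
from tensorpack.tfutils.summary import add_moving_summary


class ImageNetModelSketch(ModelDesc):
    # Subclasses override these; alexnet-dorefa sets 5e-6 and 'fc.*/W'.
    weight_decay = 1e-4
    weight_decay_pattern = '.*/W'

    def inputs(self):
        return [tf.placeholder(tf.uint8, [None, 224, 224, 3], 'input'),
                tf.placeholder(tf.int32, [None], 'label')]

    def build_graph(self, image, label):
        # Preprocessing lives in the base class now (simplified here).
        image = tf.cast(image, tf.float32) / 255.0
        logits = self.get_logits(image)  # supplied by the subclass
        cost = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label),
            name='cross_entropy_loss')
        wd_cost = regularize_cost(self.weight_decay_pattern,
                                  tf.contrib.layers.l2_regularizer(self.weight_decay),
                                  name='regularize_cost')
        add_moving_summary(cost, wd_cost)
        return tf.add_n([cost, wd_cost], name='cost')

    def get_logits(self, image):
        raise NotImplementedError("subclasses return the 1000-way logits")
```

Under this contract, the new `Model(ImageNetModel)` in the diff only overrides the two weight-decay attributes and implements `get_logits()`, which is why the cross-entropy, top-1/top-5 error, and `regularize_cost` block could be deleted from the example script.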
@@ -30,7 +30,7 @@ This Inception-BN script reaches 27% single-crop error after 300k steps with 6 G
 This VGG16 script, when trained with 32x8 batch size, reaches the following
 error rate after 100 epochs (30h with 8 P100s). This reproduces the VGG
-experiements in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).
+experiments in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).

 | No Normalization                  | Batch Normalization | Group Normalization |
 |:---------------------------------|---------------------|--------------------:|
...