Commit fc9e45b0 authored by Yuxin Wu

fix bn bug

parent 87569ca6
@@ -43,4 +43,4 @@ To visualize the agent:
 ./DQN.py --rom breakout.bin --task play --load pretrained.model
 ```
-A3C code will be released at the end of August.
+A3C code will be released in about a week.
@@ -32,15 +32,16 @@ class Model(ModelDesc):
                InputVar(tf.int32, [None], 'label')
               ]

-    def _build_graph(self, input_vars, is_training):
+    def _build_graph(self, input_vars):
         image, label = input_vars
+        is_training = get_current_tower_context().is_training
         keep_prob = tf.constant(0.5 if is_training else 1.0)

         if is_training:
             tf.image_summary("train_image", image, 10)
         image = image / 4.0  # just to make range smaller
-        with argscope(Conv2D, nl=BNReLU(is_training), use_bias=False, kernel_shape=3):
+        with argscope(Conv2D, nl=BNReLU(), use_bias=False, kernel_shape=3):
             logits = LinearWrap(image) \
                 .Conv2D('conv1.1', out_channel=64) \
                 .Conv2D('conv1.2', out_channel=64) \
@@ -22,8 +22,8 @@ class Model(ModelDesc):
         return [InputVar(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
                 InputVar(tf.int32, (None,), 'label') ]

-    def _build_graph(self, input_vars, is_training):
-        is_training = bool(is_training)
+    def _build_graph(self, input_vars):
+        is_training = get_current_tower_context().is_training

         keep_prob = tf.constant(0.5 if is_training else 1.0)
         image, label = input_vars
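For context, both example hunks above make the same change: `_build_graph` no longer receives an `is_training` argument, and the model reads it from the current tower context instead. A minimal sketch of the resulting pattern (not part of this commit; the input-declaration method name and the examples' usual `from tensorpack import *` star import exposing `ModelDesc`, `InputVar` and `get_current_tower_context` are assumptions):

```python
import tensorflow as tf
from tensorpack import *  # assumed to export ModelDesc, InputVar, get_current_tower_context

class ExampleModel(ModelDesc):
    def _get_input_vars(self):
        # declare inputs the same way the examples above do
        return [InputVar(tf.float32, (None, 28, 28), 'input'),
                InputVar(tf.int32, (None,), 'label')]

    def _build_graph(self, input_vars):
        image, label = input_vars
        # is_training is no longer an argument; query the tower context for it
        is_training = get_current_tower_context().is_training
        keep_prob = tf.constant(0.5 if is_training else 1.0)
        # ... build the rest of the graph using is_training / keep_prob as before
```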
@@ -72,20 +72,20 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
                 tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_var)
     else:
         assert not use_local_stat
-        with tf.name_scope(None):
-            ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
         if ctx.is_main_tower:
             # not training, but main tower. need to create the vars
             with tf.name_scope(None):
+                ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
                 ema_apply_op = ema.apply([batch_mean, batch_var])
             ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
         else:
             # use statistics in another tower
             G = tf.get_default_graph()
             # figure out the var name
-            mean_var_name = ema.average_name(batch_mean) + ':0'
-            var_var_name = ema.average_name(batch_var) + ':0'
+            with tf.name_scope(None):
+                ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
+                mean_var_name = ema.average_name(batch_mean) + ':0'
+                var_var_name = ema.average_name(batch_var) + ':0'
             ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name)
             ema_var = ctx.find_tensor_in_main_tower(G, var_var_name)
             #logger.info("In prediction, using {} instead of {} for {}".format(
@@ -4,6 +4,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 from abc import ABCMeta, abstractmethod
+import re
 import tensorflow as tf
 from collections import namedtuple
 import inspect
@@ -64,14 +64,14 @@ def LeakyReLU(x, alpha, name=None):
     return tf.mul(x, 0.5, name=name)

 # I'm not a layer, but I return a nonlinearity.
-def BNReLU(is_training, **kwargs):
+def BNReLU(is_training=None, **kwargs):
     """
     :param is_traning: boolean
     :param kwargs: args for BatchNorm
     :returns: an activation function that performs BN + ReLU (a too common combination)
     """
     def BNReLU(x, name=None):
-        x = BatchNorm('bn', x, is_training, **kwargs)
+        x = BatchNorm('bn', x, use_local_stat=is_training, **kwargs)
         x = tf.nn.relu(x, name=name)
         return x
     return BNReLU
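With the new default, `BNReLU()` can be called with no argument: `use_local_stat` stays `None`, and `BatchNorm` then chooses between batch statistics and the moving averages based on the tower context. A minimal usage sketch (not from this commit), mirroring the `argscope` line in the example hunk above; the variable names are placeholders:

```python
# inside some _build_graph: every Conv2D output goes through BN followed by ReLU
with argscope(Conv2D, nl=BNReLU(), use_bias=False, kernel_shape=3):
    l = Conv2D('conv1.1', image, out_channel=64)
    l = Conv2D('conv1.2', l, out_channel=64)
```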