Commit fc9e45b0 authored by Yuxin Wu

fix bn bug

parent 87569ca6
@@ -43,4 +43,4 @@ To visualize the agent:
 ./DQN.py --rom breakout.bin --task play --load pretrained.model
 ```
-A3C code will be released at the end of August.
+A3C code will be released in about a week.

@@ -32,15 +32,16 @@ class Model(ModelDesc):
                 InputVar(tf.int32, [None], 'label')
                ]

-    def _build_graph(self, input_vars, is_training):
+    def _build_graph(self, input_vars):
         image, label = input_vars
+        is_training = get_current_tower_context().is_training
         keep_prob = tf.constant(0.5 if is_training else 1.0)

         if is_training:
             tf.image_summary("train_image", image, 10)

         image = image / 4.0    # just to make range smaller
-        with argscope(Conv2D, nl=BNReLU(is_training), use_bias=False, kernel_shape=3):
+        with argscope(Conv2D, nl=BNReLU(), use_bias=False, kernel_shape=3):
             logits = LinearWrap(image) \
                 .Conv2D('conv1.1', out_channel=64) \
                 .Conv2D('conv1.2', out_channel=64) \
...
@@ -22,8 +22,8 @@ class Model(ModelDesc):
         return [InputVar(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
                 InputVar(tf.int32, (None,), 'label') ]

-    def _build_graph(self, input_vars, is_training):
-        is_training = bool(is_training)
+    def _build_graph(self, input_vars):
+        is_training = get_current_tower_context().is_training
         keep_prob = tf.constant(0.5 if is_training else 1.0)
         image, label = input_vars
...
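
In both examples, `_build_graph` no longer receives `is_training` as an argument; the flag is read from the ambient tower context instead. As a rough sketch of that pattern (illustrative only, not tensorpack's actual implementation; the names `TowerContext` and `get_current_tower_context` are the ones the diff uses):

```python
# Minimal sketch of a tower-context pattern: a context object is installed
# while each tower's graph is built, and layer code reads flags from it.
_CURRENT_TOWER_CONTEXT = None

class TowerContext:
    def __init__(self, name, is_training):
        self.name = name
        self.is_training = is_training

    def __enter__(self):
        global _CURRENT_TOWER_CONTEXT
        _CURRENT_TOWER_CONTEXT = self
        return self

    def __exit__(self, *exc):
        global _CURRENT_TOWER_CONTEXT
        _CURRENT_TOWER_CONTEXT = None

def get_current_tower_context():
    return _CURRENT_TOWER_CONTEXT

# A trainer would then build each tower inside such a context, e.g.:
# with TowerContext('tower0', is_training=True):
#     model._build_graph(input_vars)
```

This removes the need to thread `is_training` through every layer call, which is what makes the `BNReLU()` simplification below possible.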
@@ -72,18 +72,18 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
             tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_var)
     else:
         assert not use_local_stat
-        with tf.name_scope(None):
-            ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
         if ctx.is_main_tower:
             # not training, but main tower. need to create the vars
+            with tf.name_scope(None):
+                ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
             ema_apply_op = ema.apply([batch_mean, batch_var])
             ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
         else:
             # use statistics in another tower
             G = tf.get_default_graph()
             # figure out the var name
+            with tf.name_scope(None):
+                ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
             mean_var_name = ema.average_name(batch_mean) + ':0'
             var_var_name = ema.average_name(batch_var) + ':0'
             ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name)
...
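
The `with tf.name_scope(None)` wrapper is the point of this hunk: it clears the enclosing tower's name scope, so the moving-average names are the same in every tower and `ema.average_name()` in a non-main tower resolves to the tensors created in the main tower. A small standalone demonstration of the scope-clearing behavior (this relies on the TF 0.x-era semantics of `name_scope(None)`, independent of tensorpack):

```python
import tensorflow as tf

with tf.name_scope('tower0'):
    inside = tf.constant(1.0, name='c')       # named 'tower0/c'
    with tf.name_scope(None):                  # reset to the root scope
        cleared = tf.constant(1.0, name='c')   # named 'c', no tower prefix

print(inside.name)   # tower0/c:0
print(cleared.name)  # c:0
```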
@@ -4,6 +4,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 from abc import ABCMeta, abstractmethod
 import re
+import tensorflow as tf
 from collections import namedtuple
 import inspect
...
@@ -64,14 +64,14 @@ def LeakyReLU(x, alpha, name=None):
     return tf.mul(x, 0.5, name=name)

 # I'm not a layer, but I return a nonlinearity.
-def BNReLU(is_training, **kwargs):
+def BNReLU(is_training=None, **kwargs):
     """
     :param is_training: boolean
     :param kwargs: args for BatchNorm
     :returns: an activation function that performs BN + ReLU (a too common combination)
     """
     def BNReLU(x, name=None):
-        x = BatchNorm('bn', x, is_training, **kwargs)
+        x = BatchNorm('bn', x, use_local_stat=is_training, **kwargs)
         x = tf.nn.relu(x, name=name)
         return x
     return BNReLU
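
With the new default `is_training=None`, callers can simply write `BNReLU()`: the value flows into `BatchNorm` as `use_local_stat=None`, which (matching the `use_local_stat=None` default in the `BatchNorm` signature above) presumably lets the tower context decide, in lines outside this hunk. A minimal usage sketch, where `image` stands for any input tensor from the examples above:

```python
# Usage sketch with the new signature; no is_training needs to be threaded in.
act = BNReLU()                    # use_local_stat=None: defer to tower context
out = act(image, name='output')   # BatchNorm('bn', ...) followed by ReLU

# or, as in the cifar example:
# with argscope(Conv2D, nl=BNReLU(), use_bias=False, kernel_shape=3):
#     ...
```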