Commit 756dbc70 authored by Yuxin Wu's avatar Yuxin Wu

update var name for batch_norm

parent 117fb29f
...@@ -71,7 +71,7 @@ class Model(ModelDesc): ...@@ -71,7 +71,7 @@ class Model(ModelDesc):
""" image: [0,255]""" """ image: [0,255]"""
image = image / 255.0 image = image / 255.0
with argscope(Conv2D, nl=PReLU.f, use_bias=True): with argscope(Conv2D, nl=PReLU.f, use_bias=True):
l = (LinearWrap(image) return (LinearWrap(image)
.Conv2D('conv0', out_channel=32, kernel_shape=5) .Conv2D('conv0', out_channel=32, kernel_shape=5)
.MaxPooling('pool0', 2) .MaxPooling('pool0', 2)
.Conv2D('conv1', out_channel=32, kernel_shape=5) .Conv2D('conv1', out_channel=32, kernel_shape=5)
...@@ -87,7 +87,6 @@ class Model(ModelDesc): ...@@ -87,7 +87,6 @@ class Model(ModelDesc):
.FullyConnected('fc0', 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name)) .FullyConnected('fc0', 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name))
.FullyConnected('fct', NUM_ACTIONS, nl=tf.identity)()) .FullyConnected('fct', NUM_ACTIONS, nl=tf.identity)())
return l
def _build_graph(self, inputs, is_training): def _build_graph(self, inputs, is_training):
state, action, reward, next_state, isOver = inputs state, action, reward, next_state, isOver = inputs
......
...@@ -55,17 +55,17 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5): ...@@ -55,17 +55,17 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
emaname = 'EMA' emaname = 'EMA'
in_main_tower = not batch_mean.name.startswith('towerp') in_main_tower = not batch_mean.name.startswith('towerp')
if in_main_tower: if in_main_tower:
ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname) with tf.name_scope(None): # https://github.com/tensorflow/tensorflow/issues/2740
ema_apply_op = ema.apply([batch_mean, batch_var]) ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var) ema_apply_op = ema.apply([batch_mean, batch_var])
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
else: else:
# use training-statistics in prediction # use training-statistics in prediction
assert not use_local_stat assert not use_local_stat
# XXX have to do this again to get actual name. see issue: with tf.name_scope(None): # https://github.com/tensorflow/tensorflow/issues/2740
# https://github.com/tensorflow/tensorflow/issues/2740 ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname) ema_apply_op = ema.apply([batch_mean, batch_var])
ema_apply_op = ema.apply([batch_mean, batch_var]) ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
G = tf.get_default_graph() G = tf.get_default_graph()
# find training statistics in training tower # find training statistics in training tower
......
...@@ -10,7 +10,7 @@ from ..utils import * ...@@ -10,7 +10,7 @@ from ..utils import *
from . import get_global_step_var from . import get_global_step_var
__all__ = ['create_summary', 'add_param_summary', 'add_activation_summary', __all__ = ['create_summary', 'add_param_summary', 'add_activation_summary',
'summary_moving_average'] 'add_moving_summary', 'summary_moving_average']
def create_summary(name, v): def create_summary(name, v):
""" """
...@@ -42,8 +42,8 @@ def add_param_summary(summary_lists): ...@@ -42,8 +42,8 @@ def add_param_summary(summary_lists):
""" """
Add summary for all trainable variables matching the regex Add summary for all trainable variables matching the regex
:param summary_lists: list of (regex, [list of action to perform]). :param summary_lists: list of (regex, [list of summary type to perform]).
Action can be 'mean', 'scalar', 'histogram', 'sparsity', 'rms' Type can be 'mean', 'scalar', 'histogram', 'sparsity', 'rms'
""" """
def perform(var, action): def perform(var, action):
ndim = var.get_shape().ndims ndim = var.get_shape().ndims
...@@ -66,7 +66,7 @@ def add_param_summary(summary_lists): ...@@ -66,7 +66,7 @@ def add_param_summary(summary_lists):
tf.scalar_summary(name + '/rms', tf.scalar_summary(name + '/rms',
tf.sqrt(tf.reduce_mean(tf.square(var)))) tf.sqrt(tf.reduce_mean(tf.square(var))))
return return
raise RuntimeError("Unknown action {}".format(action)) raise RuntimeError("Unknown summary type: {}".format(action))
import re import re
params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
...@@ -79,6 +79,9 @@ def add_param_summary(summary_lists): ...@@ -79,6 +79,9 @@ def add_param_summary(summary_lists):
for act in actions: for act in actions:
perform(p, act) perform(p, act)
def add_moving_summary(v):
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, v)
def summary_moving_average(): def summary_moving_average():
""" Create a MovingAverage op and summary for all variables in """ Create a MovingAverage op and summary for all variables in
MOVING_SUMMARY_VARS_KEY. MOVING_SUMMARY_VARS_KEY.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment