Commit a59e46cd authored by Yuxin Wu's avatar Yuxin Wu

[WIP] completely move summary_moving_average to step callback

parent 4cd01111
...@@ -14,7 +14,7 @@ import cv2 ...@@ -14,7 +14,7 @@ import cv2
from tensorpack import * from tensorpack import *
from tensorpack.utils.viz import * from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.globvars import globalns as CFG, use_global_argument from tensorpack.utils.globvars import globalns as CFG, use_global_argument
import tensorpack.tfutils.symbolic_functions as symbf import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, RandomZData, GANModelDesc from GAN import GANTrainer, RandomZData, GANModelDesc
......
...@@ -8,13 +8,12 @@ import numpy as np ...@@ -8,13 +8,12 @@ import numpy as np
import time import time
from tensorpack import (FeedfreeTrainerBase, TowerContext, from tensorpack import (FeedfreeTrainerBase, TowerContext,
get_global_step_var, QueueInput, ModelDesc) get_global_step_var, QueueInput, ModelDesc)
from tensorpack.tfutils.summary import summary_moving_average, add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.gradproc import apply_grad_processors, CheckGradient from tensorpack.tfutils.gradproc import apply_grad_processors, CheckGradient
from tensorpack.dataflow import DataFlow from tensorpack.dataflow import DataFlow
class GANModelDesc(ModelDesc): class GANModelDesc(ModelDesc):
def collect_variables(self): def collect_variables(self):
"""Extract variables by prefix """Extract variables by prefix
""" """
......
...@@ -14,7 +14,7 @@ import cv2 ...@@ -14,7 +14,7 @@ import cv2
from tensorpack import * from tensorpack import *
from tensorpack.utils.viz import * from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average from tensorpack.tfutils.summary import add_moving_summary
import tensorpack.tfutils.symbolic_functions as symbf import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, GANModelDesc from GAN import GANTrainer, GANModelDesc
......
...@@ -36,8 +36,13 @@ This is a visualization from tensorboard. Left to right: original, ground truth, ...@@ -36,8 +36,13 @@ This is a visualization from tensorboard. Left to right: original, ground truth,
## InfoGAN-mnist.py ## InfoGAN-mnist.py
Reproduce one mnist experiement in InfoGAN. Reproduce the mnist experiement in InfoGAN.
By assuming 10 latent variables corresponding to a categorical distribution, and 2 latent variables corresponding to an "uniform distributioN" and maximizing mutual information, It assumes 10 latent variables corresponding to a categorical distribution, 2 latent variables corresponding to a uniform distribution.
the network learns to map the 10 variables to 10 digits and the other two latent variables to rotation and thickness in a completely unsupervised way. It then maximizes mutual information between these latent variables and the image, and learns interpretable latent representation.
![infogan](demo/InfoGAN-mnist.jpg) ![infogan](demo/InfoGAN-mnist.jpg)
* Left: 10 latent variables corresponding to 10 digits.
* Middle: 1 continuous latent variable controlled the rotation.
* Right: another continuous latent variable controlled the thickness.
...@@ -5,10 +5,13 @@ ...@@ -5,10 +5,13 @@
""" Some common step callbacks. """ """ Some common step callbacks. """
import tensorflow as tf
import re
from six.moves import zip from six.moves import zip
from ..utils import logger from ..utils import logger
from ..tfutils.common import get_op_tensor_name from ..utils.naming import MOVING_SUMMARY_VARS_KEY
from ..tfutils.summary import summary_moving_average from ..tfutils.common import get_op_tensor_name, get_global_step_var
from .base import Callback from .base import Callback
__all__ = ['StepStatPrinter', 'SummaryMovingAverage'] __all__ = ['StepStatPrinter', 'SummaryMovingAverage']
...@@ -28,16 +31,37 @@ class StepStatPrinter(Callback): ...@@ -28,16 +31,37 @@ class StepStatPrinter(Callback):
return self._names return self._names
def _trigger_step(self, *args): def _trigger_step(self, *args):
assert len(args) == len(self._names), len(args)
for n, v in zip(self._names, args): for n, v in zip(self._names, args):
logger.info("{}: {}".format(n, v)) logger.info("{}: {}".format(n, v))
class SummaryMovingAverage(Callback): class SummaryMovingAverage(Callback):
""" Maintain the moving average of the tensors added by :func:`summary.add_moving_summary` """ Maintain the moving average of the tensors
in every step, and summarize them. in every step, and summarize them.
""" """
def __init__(self, collection=MOVING_SUMMARY_VARS_KEY, decay=0.95):
"""
Args:
collection(str): the collection of tensors to summarize. The
default would work with :func:`add_moving_summary`.
decay(float): the decay of the moving average.
"""
self._collection = collection
self._decay = decay
def _setup_graph(self): def _setup_graph(self):
self.ema_op = summary_moving_average() tensors = set(tf.get_collection(self._collection))
# TODO will produce tower0/xxx. not elegant
with tf.name_scope(None):
averager = tf.train.ExponentialMovingAverage(
self._decay, num_updates=get_global_step_var(), name='EMA')
avg_maintain_op = averager.apply(tensors)
for idx, c in enumerate(tensors):
name = re.sub('tower[p0-9]+/', '', c.op.name)
tf.summary.scalar(name + '-summary', averager.average(c))
self.ema_op = avg_maintain_op
def _extra_fetches(self): def _extra_fetches(self):
return [self.ema_op] return [self.ema_op]
...@@ -6,15 +6,13 @@ import six ...@@ -6,15 +6,13 @@ import six
import tensorflow as tf import tensorflow as tf
import re import re
from ..utils.argtools import memoized
from ..utils import logger from ..utils import logger
from ..utils.naming import MOVING_SUMMARY_VARS_KEY from ..utils.naming import MOVING_SUMMARY_VARS_KEY
from .tower import get_current_tower_context from .tower import get_current_tower_context
from . import get_global_step_var
from .symbolic_functions import rms from .symbolic_functions import rms
__all__ = ['create_scalar_summary', 'add_param_summary', 'add_activation_summary', __all__ = ['create_scalar_summary', 'add_param_summary', 'add_activation_summary',
'add_moving_summary', 'summary_moving_average'] 'add_moving_summary']
def create_scalar_summary(name, v): def create_scalar_summary(name, v):
...@@ -116,29 +114,3 @@ def add_moving_summary(v, *args): ...@@ -116,29 +114,3 @@ def add_moving_summary(v, *args):
for x in v: for x in v:
assert x.get_shape().ndims == 0, x.get_shape() assert x.get_shape().ndims == 0, x.get_shape()
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x)
@memoized
def summary_moving_average(tensors=None):
"""
Create a MovingAverage Op and add summary Op for all the moving averages.
This is called by the trainer.
Args:
tensors(list): list of tf.Tensor to summary. hefaults to the
collection ````MOVING_SUMMARY_VARS_KEY``.
Returns:
tf.Operation: an op to maintain these average.
"""
if tensors is None:
tensors = set(tf.get_collection(MOVING_SUMMARY_VARS_KEY))
# TODO will produce tower0/xxx. not elegant
with tf.name_scope(None):
averager = tf.train.ExponentialMovingAverage(
0.95, num_updates=get_global_step_var(), name='EMA')
avg_maintain_op = averager.apply(tensors)
for idx, c in enumerate(tensors):
name = re.sub('tower[p0-9]+/', '', c.op.name)
tf.summary.scalar(name + '-summary', averager.average(c))
return avg_maintain_op
...@@ -177,7 +177,10 @@ class Trainer(object): ...@@ -177,7 +177,10 @@ class Trainer(object):
if self.coord.should_stop(): if self.coord.should_stop():
return return
fetch_data = self.run_step() # implemented by subclass fetch_data = self.run_step() # implemented by subclass
if fetch_data: if fetch_data is None:
# the old Trainer
callbacks.trigger_step()
else:
callbacks.trigger_step(*fetch_data) callbacks.trigger_step(*fetch_data)
# trigger epoch outside the timing region. # trigger epoch outside the timing region.
self.trigger_epoch() self.trigger_epoch()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment