Commit a51e2de4 authored by Yuxin Wu

a missing part of the last commit.

parent e3045eda
@@ -6,23 +6,20 @@
""" Some common step callbacks. """
import tensorflow as tf
import re
from six.moves import zip
import tqdm
from ..utils import logger, get_tqdm_kwargs
from ..utils.naming import (
MOVING_SUMMARY_VARS_KEY,
GLOBAL_STEP_INCR_VAR_NAME,
from ..utils.naming import (GLOBAL_STEP_INCR_OP_NAME,
LOCAL_STEP_OP_NAME)
from ..tfutils.common import get_op_tensor_name, get_global_step_var, get_global_step_value
from .base import Callback
__all__ = ['StepStatPrinter', 'MaintainStepCounter',
'SummaryMovingAverage', 'ProgressBar']
__all__ = ['StepTensorPrinter', 'MaintainStepCounter',
'ProgressBar']
class StepStatPrinter(Callback):
class StepTensorPrinter(Callback):
""" It prints the value of some tensors in each step.
It's just a demo of how trigger_step works but you should in general use
:func:`symbolic_functions.print_stat` or :func:`tf.Print` instead. """
@@ -30,10 +27,10 @@ class StepStatPrinter(Callback):
def __init__(self, names):
"""
Args:
names(list): list of string, the names of the tensor to print.
names(list): list of string, the names of the tensors to print.
"""
names = [get_op_tensor_name(n)[1] for n in names]
logger.warn("Using print_stat or tf.Print in the graph is much faster than StepStatPrinter!")
logger.warn("Using print_stat or tf.Print in the graph is much faster than StepTensorPrinter!")
self._names = names
def _extra_fetches(self):
@@ -53,8 +50,11 @@ class MaintainStepCounter(Callback):
"""
def _setup_graph(self):
# ensure it exists
get_global_step_var()
self.gs_incr_var = self.trainer.sess.graph.get_tensor_by_name(GLOBAL_STEP_INCR_VAR_NAME)
gs_var = get_global_step_var()
with tf.name_scope(None):
self.gs_incr_var = tf.assign_add(
gs_var, 1,
name=GLOBAL_STEP_INCR_OP_NAME)
self.local_step = tf.mod(
self.gs_incr_var, self.trainer.config.step_per_epoch,
name=LOCAL_STEP_OP_NAME)
@@ -68,37 +68,6 @@ class MaintainStepCounter(Callback):
return [self.gs_incr_var.op]
class SummaryMovingAverage(Callback):
""" Maintain the moving average of the tensors
in every step, and summarize them. Enabled by default.
"""
def __init__(self, collection=MOVING_SUMMARY_VARS_KEY, decay=0.95):
"""
Args:
collection(str): the collection of tensors to summarize. The
default would work with :func:`add_moving_summary`.
decay(float): the decay of the moving average.
"""
self._collection = collection
self._decay = decay
def _setup_graph(self):
tensors = set(tf.get_collection(self._collection))
# TODO will produce tower0/xxx. not elegant
with tf.name_scope(None):
averager = tf.train.ExponentialMovingAverage(
self._decay, num_updates=get_global_step_var(), name='EMA')
avg_maintain_op = averager.apply(tensors)
for idx, c in enumerate(tensors):
name = re.sub('tower[p0-9]+/', '', c.op.name)
tf.summary.scalar(name + '-summary', averager.average(c))
self.ema_op = avg_maintain_op
def _extra_fetches(self):
return [self.ema_op]
class ProgressBar(Callback):
""" A progress bar based on tqdm. Enabled by default. """
def _before_train(self):
......
......@@ -5,17 +5,23 @@
import tensorflow as tf
from ..utils.naming import GLOBAL_STEP_VAR_NAME, GLOBAL_STEP_OP_NAME, GLOBAL_STEP_INCR_OP_NAME
from ..utils.naming import (
GLOBAL_STEP_VAR_NAME,
GLOBAL_STEP_OP_NAME,
LOCAL_STEP_VAR_NAME)
from ..utils import logger
from ..utils.argtools import memoized
__all__ = ['get_default_sess_config',
'get_global_step_value',
'get_global_step_var',
'get_local_step_var',
'get_op_tensor_name',
'get_tensors_by_names',
'get_op_or_tensor_by_name',
'get_tf_version',
'get_name_scope_name'
'get_name_scope_name',
]
@@ -56,8 +62,6 @@ def get_global_step_var():
var = tf.get_variable(GLOBAL_STEP_OP_NAME,
initializer=0,
trainable=False, dtype=tf.int32)
# also create the incr operation
tf.assign_add(var, 1, name=GLOBAL_STEP_INCR_OP_NAME)
return var
@@ -70,6 +74,15 @@ def get_global_step_value():
get_global_step_var())
@memoized
def get_local_step_var():
try:
return tf.get_default_graph().get_tensor_by_name(LOCAL_STEP_VAR_NAME)
except KeyError:
logger.warn("get_local_step_var() is only available to use in callbacks!")
raise
def get_op_tensor_name(name):
"""
Will automatically determine if ``name`` is a tensor name (ends with ':x')
@@ -110,14 +123,6 @@ def get_op_or_tensor_by_name(name):
return G.get_operation_by_name(name)
def get_tf_version():
"""
Returns:
int:
"""
return int(tf.__version__.split('.')[1])
def get_name_scope_name():
"""
Returns:
......
@@ -6,7 +6,8 @@ import tensorflow as tf
from ..callbacks import (
Callbacks, SummaryMovingAverage,
StatPrinter, ProgressBar, MaintainStepCounter)
StatPrinter, ProgressBar,
MaintainStepCounter)
from ..dataflow.base import DataFlow
from ..models import ModelDesc
from ..utils import logger
@@ -84,7 +85,10 @@ class TrainConfig(object):
callbacks = callbacks.cbs[:-1] # the last one is StatPrinter()
assert_type(callbacks, list)
if extra_callbacks is None:
extra_callbacks = [SummaryMovingAverage(), ProgressBar(), StatPrinter()]
extra_callbacks = [
SummaryMovingAverage(),
ProgressBar(),
StatPrinter()]
self.callbacks = [MaintainStepCounter()] + callbacks + extra_callbacks
assert_type(self.callbacks, list)
self.callbacks = Callbacks(self.callbacks)
......