Commit a51e2de4 authored by Yuxin Wu's avatar Yuxin Wu

a missing part of the last commit.

parent e3045eda
...@@ -6,23 +6,20 @@ ...@@ -6,23 +6,20 @@
""" Some common step callbacks. """ """ Some common step callbacks. """
import tensorflow as tf import tensorflow as tf
import re
from six.moves import zip from six.moves import zip
import tqdm import tqdm
from ..utils import logger, get_tqdm_kwargs from ..utils import logger, get_tqdm_kwargs
from ..utils.naming import ( from ..utils.naming import (GLOBAL_STEP_INCR_OP_NAME,
MOVING_SUMMARY_VARS_KEY, LOCAL_STEP_OP_NAME)
GLOBAL_STEP_INCR_VAR_NAME,
LOCAL_STEP_OP_NAME)
from ..tfutils.common import get_op_tensor_name, get_global_step_var, get_global_step_value from ..tfutils.common import get_op_tensor_name, get_global_step_var, get_global_step_value
from .base import Callback from .base import Callback
__all__ = ['StepStatPrinter', 'MaintainStepCounter', __all__ = ['StepTensorPrinter', 'MaintainStepCounter',
'SummaryMovingAverage', 'ProgressBar'] 'ProgressBar']
class StepStatPrinter(Callback): class StepTensorPrinter(Callback):
""" It prints the value of some tensors in each step. """ It prints the value of some tensors in each step.
It's just a demo of how trigger_step works but you should in general use It's just a demo of how trigger_step works but you should in general use
:func:`symbolic_functions.print_stat` or :func:`tf.Print` instead. """ :func:`symbolic_functions.print_stat` or :func:`tf.Print` instead. """
...@@ -30,10 +27,10 @@ class StepStatPrinter(Callback): ...@@ -30,10 +27,10 @@ class StepStatPrinter(Callback):
def __init__(self, names): def __init__(self, names):
""" """
Args: Args:
names(list): list of string, the names of the tensor to print. names(list): list of string, the names of the tensors to print.
""" """
names = [get_op_tensor_name(n)[1] for n in names] names = [get_op_tensor_name(n)[1] for n in names]
logger.warn("Using print_stat or tf.Print in the graph is much faster than StepStatPrinter!") logger.warn("Using print_stat or tf.Print in the graph is much faster than StepTensorPrinter!")
self._names = names self._names = names
def _extra_fetches(self): def _extra_fetches(self):
...@@ -53,11 +50,14 @@ class MaintainStepCounter(Callback): ...@@ -53,11 +50,14 @@ class MaintainStepCounter(Callback):
""" """
def _setup_graph(self): def _setup_graph(self):
# ensure it exists # ensure it exists
get_global_step_var() gs_var = get_global_step_var()
self.gs_incr_var = self.trainer.sess.graph.get_tensor_by_name(GLOBAL_STEP_INCR_VAR_NAME) with tf.name_scope(None):
self.local_step = tf.mod( self.gs_incr_var = tf.assign_add(
self.gs_incr_var, self.trainer.config.step_per_epoch, gs_var, 1,
name=LOCAL_STEP_OP_NAME) name=GLOBAL_STEP_INCR_OP_NAME)
self.local_step = tf.mod(
self.gs_incr_var, self.trainer.config.step_per_epoch,
name=LOCAL_STEP_OP_NAME)
def _before_train(self): def _before_train(self):
gs_val = get_global_step_value() gs_val = get_global_step_value()
...@@ -68,37 +68,6 @@ class MaintainStepCounter(Callback): ...@@ -68,37 +68,6 @@ class MaintainStepCounter(Callback):
return [self.gs_incr_var.op] return [self.gs_incr_var.op]
class SummaryMovingAverage(Callback):
""" Maintain the moving average of the tensors
in every step, and summarize them. Enabled by default.
"""
def __init__(self, collection=MOVING_SUMMARY_VARS_KEY, decay=0.95):
"""
Args:
collection(str): the collection of tensors to summarize. The
default would work with :func:`add_moving_summary`.
decay(float): the decay of the moving average.
"""
self._collection = collection
self._decay = decay
def _setup_graph(self):
tensors = set(tf.get_collection(self._collection))
# TODO will produce tower0/xxx. not elegant
with tf.name_scope(None):
averager = tf.train.ExponentialMovingAverage(
self._decay, num_updates=get_global_step_var(), name='EMA')
avg_maintain_op = averager.apply(tensors)
for idx, c in enumerate(tensors):
name = re.sub('tower[p0-9]+/', '', c.op.name)
tf.summary.scalar(name + '-summary', averager.average(c))
self.ema_op = avg_maintain_op
def _extra_fetches(self):
return [self.ema_op]
class ProgressBar(Callback): class ProgressBar(Callback):
""" A progress bar based on tqdm. Enabled by default. """ """ A progress bar based on tqdm. Enabled by default. """
def _before_train(self): def _before_train(self):
......
...@@ -5,17 +5,23 @@ ...@@ -5,17 +5,23 @@
import tensorflow as tf import tensorflow as tf
from ..utils.naming import GLOBAL_STEP_VAR_NAME, GLOBAL_STEP_OP_NAME, GLOBAL_STEP_INCR_OP_NAME from ..utils.naming import (
GLOBAL_STEP_VAR_NAME,
GLOBAL_STEP_OP_NAME,
LOCAL_STEP_VAR_NAME)
from ..utils import logger
from ..utils.argtools import memoized from ..utils.argtools import memoized
__all__ = ['get_default_sess_config', __all__ = ['get_default_sess_config',
'get_global_step_value', 'get_global_step_value',
'get_global_step_var', 'get_global_step_var',
'get_local_step_var',
'get_op_tensor_name', 'get_op_tensor_name',
'get_tensors_by_names', 'get_tensors_by_names',
'get_op_or_tensor_by_name', 'get_op_or_tensor_by_name',
'get_tf_version', 'get_name_scope_name',
'get_name_scope_name'
] ]
...@@ -56,8 +62,6 @@ def get_global_step_var(): ...@@ -56,8 +62,6 @@ def get_global_step_var():
var = tf.get_variable(GLOBAL_STEP_OP_NAME, var = tf.get_variable(GLOBAL_STEP_OP_NAME,
initializer=0, initializer=0,
trainable=False, dtype=tf.int32) trainable=False, dtype=tf.int32)
# also create the incr operation
tf.assign_add(var, 1, name=GLOBAL_STEP_INCR_OP_NAME)
return var return var
...@@ -70,6 +74,15 @@ def get_global_step_value(): ...@@ -70,6 +74,15 @@ def get_global_step_value():
get_global_step_var()) get_global_step_var())
@memoized
def get_local_step_var():
try:
return tf.get_default_graph().get_tensor_by_name(LOCAL_STEP_VAR_NAME)
except KeyError:
logger.warn("get_local_step_var() is only available to use in callbacks!")
raise
def get_op_tensor_name(name): def get_op_tensor_name(name):
""" """
Will automatically determine if ``name`` is a tensor name (ends with ':x') Will automatically determine if ``name`` is a tensor name (ends with ':x')
...@@ -110,14 +123,6 @@ def get_op_or_tensor_by_name(name): ...@@ -110,14 +123,6 @@ def get_op_or_tensor_by_name(name):
return G.get_operation_by_name(name) return G.get_operation_by_name(name)
def get_tf_version():
"""
Returns:
int:
"""
return int(tf.__version__.split('.')[1])
def get_name_scope_name(): def get_name_scope_name():
""" """
Returns: Returns:
......
...@@ -6,7 +6,8 @@ import tensorflow as tf ...@@ -6,7 +6,8 @@ import tensorflow as tf
from ..callbacks import ( from ..callbacks import (
Callbacks, SummaryMovingAverage, Callbacks, SummaryMovingAverage,
StatPrinter, ProgressBar, MaintainStepCounter) StatPrinter, ProgressBar,
MaintainStepCounter)
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
from ..models import ModelDesc from ..models import ModelDesc
from ..utils import logger from ..utils import logger
...@@ -84,7 +85,10 @@ class TrainConfig(object): ...@@ -84,7 +85,10 @@ class TrainConfig(object):
callbacks = callbacks.cbs[:-1] # the last one is StatPrinter() callbacks = callbacks.cbs[:-1] # the last one is StatPrinter()
assert_type(callbacks, list) assert_type(callbacks, list)
if extra_callbacks is None: if extra_callbacks is None:
extra_callbacks = [SummaryMovingAverage(), ProgressBar(), StatPrinter()] extra_callbacks = [
SummaryMovingAverage(),
ProgressBar(),
StatPrinter()]
self.callbacks = [MaintainStepCounter()] + callbacks + extra_callbacks self.callbacks = [MaintainStepCounter()] + callbacks + extra_callbacks
assert_type(self.callbacks, list) assert_type(self.callbacks, list)
self.callbacks = Callbacks(self.callbacks) self.callbacks = Callbacks(self.callbacks)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment