Commit cc89b105 authored by Yuxin Wu

EMA callback doesn't create variables itself. Add old SaverRestore back to be fast.

parent 20d1af11
@@ -68,6 +68,10 @@ def get_config():
 class WGANTrainer(FeedfreeTrainerBase):
+    """ A new trainer which runs two optimization ops with 5:1 ratio.
+    This is to be consistent with the original code, but I found just
+    running them 1:1 (i.e. using the existing GANTrainer) also works well.
+    """
     def __init__(self, config):
         self._input_method = QueueInput(config.dataflow)
         super(WGANTrainer, self).__init__(config)
...
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: ls-checkpoint.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
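# Prints the variable names and shapes stored in a checkpoint.
# Usage: ./ls-checkpoint.py <model-xxxx prefix or 'checkpoint' file>
#   e.g. ./ls-checkpoint.py train_log/model-10000   (example path, not from this commit)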
import tensorflow as tf
import sys
import pprint
from tensorpack.tfutils.varmanip import get_checkpoint_path
path = get_checkpoint_path(sys.argv[1])
reader = tf.train.NewCheckpointReader(path)
pprint.pprint(reader.get_variable_to_shape_map())
@@ -4,10 +4,8 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 import tensorflow as tf
-import re
-from ..utils.naming import MOVING_SUMMARY_VARS_KEY
-from ..tfutils.common import get_global_step_var
+from ..utils.naming import MOVING_SUMMARY_OPS_KEY
 from .base import Callback

 __all__ = ['MovingAverageSummary']
@@ -17,28 +15,18 @@ class MovingAverageSummary(Callback):
     """ Maintain the moving average of the tensors
     in every step, and summarize them. Enabled by default.
     """
-    def __init__(self, collection=MOVING_SUMMARY_VARS_KEY, decay=0.95):
+    def __init__(self, collection=MOVING_SUMMARY_OPS_KEY):
         """
         Args:
-            collection(str): the collection of tensors to summarize. The
-                default would work with :func:`add_moving_summary`.
-            decay(float): the decay of the moving average.
+            collection(str): the collection of EMA-maintaining ops.
+                The default would work with :func:`add_moving_summary()`,
+                but you can use some others.
         """
         self._collection = collection
-        self._decay = decay

     def _setup_graph(self):
-        tensors = set(tf.get_collection(self._collection))
-        # TODO will produce tower0/xxx. not elegant
-        with tf.name_scope(None):
-            averager = tf.train.ExponentialMovingAverage(
-                self._decay, num_updates=get_global_step_var(), name='EMA')
-            avg_maintain_op = averager.apply(tensors)
-            for idx, c in enumerate(tensors):
-                name = re.sub('tower[p0-9]+/', '', c.op.name)
-                tf.summary.scalar(name + '-summary', averager.average(c))
-        self.ema_op = avg_maintain_op
+        ops = tf.get_collection(self._collection)
+        self.ema_op = tf.group(*ops, name='summary_moving_averages')

     def _extra_fetches(self):
         return [self.ema_op]
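After this change, the division of labor is: add_moving_summary builds the EMA op and its scalar summary, while the MovingAverageSummary callback merely groups and fetches whatever is in the collection. A minimal sketch of the new flow; the cost tensor and the surrounding code are illustrative placeholders, not part of this commit:

    import tensorflow as tf
    from tensorpack.tfutils.summary import add_moving_summary
    from tensorpack.utils.naming import MOVING_SUMMARY_OPS_KEY

    cost = tf.constant(1.0, name='total_cost')   # stand-in for a real training cost
    add_moving_summary(cost)                      # builds the EMA op + scalar summary, adds the op to the collection

    # what the default MovingAverageSummary callback now does with that collection, roughly:
    ops = tf.get_collection(MOVING_SUMMARY_OPS_KEY)
    ema_op = tf.group(*ops, name='summary_moving_averages')  # fetched together with the train op every step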
@@ -3,22 +3,20 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import os
-from abc import abstractmethod, ABCMeta
 import numpy as np
 import tensorflow as tf
 import six

-from ..utils import logger, PREDICT_TOWER
+from ..utils import logger
 from .common import get_op_tensor_name
 from .varmanip import (SessionUpdate, get_savename_from_varname,
                        is_training_name, get_checkpoint_path)

-__all__ = ['SessionInit', 'NewSession', 'SaverRestore',
+__all__ = ['SessionInit', 'NewSession', 'SaverRestore', 'SaverRestoreRelaxed',
            'ParamRestore', 'ChainInit',
            'JustCurrentSession', 'get_model_loader']

-@six.add_metaclass(ABCMeta)
 class SessionInit(object):
     """ Base class for utilities to initialize a session. """
     def init(self, sess):
@@ -30,23 +28,31 @@ class SessionInit(object):
         """
         self._init(sess)

-    @abstractmethod
     def _init(self, sess):
+        self._setup_graph()
+        self._run_init(sess)
+
+    def _setup_graph(self):
+        pass
+
+    def _run_init(self, sess):
         pass

 class JustCurrentSession(SessionInit):
     """ This is a no-op placeholder"""
-    def _init(self, sess):
-        pass
+    pass

 class NewSession(SessionInit):
     """
     Initialize global variables by their initializer.
     """
-    def _init(self, sess):
-        sess.run(tf.global_variables_initializer())
+    def _setup_graph(self):
+        self.op = tf.global_variables_initializer()
+
+    def _run_init(self, sess):
+        sess.run(self.op)

 class CheckpointReaderAdapter(object):
@@ -58,7 +64,7 @@ class CheckpointReaderAdapter(object):
         self._reader = reader
         m = self._reader.get_variable_to_shape_map()
         self._map = {k if k.endswith(':0') else k + ':0': v
-                     for k, v in m.iteritems()}
+                     for k, v in six.iteritems(m)}

     def get_variable_to_shape_map(self):
         return self._map
@@ -74,23 +80,77 @@ class CheckpointReaderAdapter(object):
     def has_tensor(self, name):
         return name in self._map

+    # some checkpoint might not have ':0'
+    def get_real_name(self, name):
+        if self._reader.has_tensor(name):
+            return name
+        assert self.has_tensor(name)
+        return name[:-2]

 class SaverRestore(SessionInit):
     """
     Restore a tensorflow checkpoint saved by :class:`tf.train.Saver` or :class:`ModelSaver`.
     """
     def __init__(self, model_path, prefix=None):
         """
         Args:
-            model_path (str): path to the model (model-xxxx) or a ``checkpoint`` file.
+            model_path (str): a model name (model-xxxx) or a ``checkpoint`` file.
             prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint
         """
         model_path = get_checkpoint_path(model_path)
         self.path = model_path
         self.prefix = prefix

-    def _init(self, sess):
+    def _setup_graph(self):
+        dic = self._get_restore_dict()
+        self.saver = tf.train.Saver(var_list=dic, name=str(id(dic)))
+
+    def _run_init(self, sess):
+        logger.info("Restoring checkpoint from {} ...".format(self.path))
+        self.saver.restore(sess, self.path)
+
+    @staticmethod
+    def _read_checkpoint_vars(model_path):
+        """ return a set of strings """
+        reader = tf.train.NewCheckpointReader(model_path)
+        reader = CheckpointReaderAdapter(reader)    # use an adapter to standardize the name
+        ckpt_vars = reader.get_variable_to_shape_map().keys()
+        return reader, set(ckpt_vars)
+
+    def _get_restore_dict(self):
+        reader, chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
+        graph_vars = tf.global_variables()
+        var_dict = {}
+        chkpt_vars_used = set()
+        for v in graph_vars:
+            name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
+            if reader.has_tensor(name):
+                ckpt_name = reader.get_real_name(name)
+                assert ckpt_name not in var_dict, "Restore conflict: {} and {}".format(v.name, var_dict[ckpt_name].name)
+                var_dict[ckpt_name] = v
+                chkpt_vars_used.add(name)
+            else:
+                vname = v.op.name
+                if not is_training_name(vname):
+                    logger.warn("Variable {} in the graph not found in checkpoint!".format(vname))
+        if len(chkpt_vars_used) < len(chkpt_vars):
+            unused = chkpt_vars - chkpt_vars_used
+            for name in sorted(unused):
+                if not is_training_name(name):
+                    logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
+        return var_dict
+
+class SaverRestoreRelaxed(SaverRestore):
+    """ Same as :class:`SaverRestore`, but has more relaxed constraints.
+    It allows upcasting certain variables, or reshape certain
+    variables when there is a mismatch that can be fixed.
+    Another advantage is that it doesn't add any new ops to the graph.
+    But it is also slower than :class:`SaverRestore`.
+    """
+    def _run_init(self, sess):
         logger.info(
             "Restoring checkpoint from {} ...".format(self.path))
         reader, chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
@@ -114,18 +174,6 @@ class SaverRestore(SessionInit):
             if not is_training_name(name):
                 logger.warn("Variable {} in checkpoint not found in the graph!".format(name))

-    @staticmethod
-    def _read_checkpoint_vars(model_path):
-        """ return a set of strings """
-        reader = tf.train.NewCheckpointReader(model_path)
-        reader = CheckpointReaderAdapter(reader)
-        ckpt_vars = reader.get_variable_to_shape_map().keys()
-        for v in ckpt_vars:
-            if v.startswith(PREDICT_TOWER):
-                logger.error("Found {} in checkpoint. "
-                             "But anything from prediction tower shouldn't be saved.".format(v.name))
-        return reader, set(ckpt_vars)

 class ParamRestore(SessionInit):
     """
@@ -140,7 +188,7 @@ class ParamRestore(SessionInit):
         # use varname (with :0) for consistency
         self.prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(param_dict)}

-    def _init(self, sess):
+    def _run_init(self, sess):
         variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)   # TODO
         variable_names = set([k.name for k in variables])
@@ -182,6 +230,14 @@ class ChainInit(SessionInit):
         for i in self.inits:
             i.init(sess)

+    def _setup_graph(self):
+        for i in self.inits:
+            i._setup_graph()
+
+    def _run_init(self, sess):
+        for i in self.inits:
+            i._run_init(sess)

 def get_model_loader(filename):
     """
...
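For reference, choosing between the two restorers might look like the following in user code; the TrainConfig wiring and the checkpoint path are illustrative assumptions, not part of this diff:

    from tensorpack.tfutils.sessinit import SaverRestore, SaverRestoreRelaxed

    # fast path: builds a tf.train.Saver in _setup_graph(), restores in _run_init()
    init = SaverRestore('train_log/run1/model-10000')          # hypothetical path
    # relaxed path: adds no new ops to the graph and tolerates fixable mismatches, but is slower
    # init = SaverRestoreRelaxed('train_log/run1/model-10000')
    # then e.g. TrainConfig(..., session_init=init)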
@@ -7,9 +7,10 @@ import tensorflow as tf
 import re

 from ..utils import log_deprecated
-from ..utils.naming import MOVING_SUMMARY_VARS_KEY
+from ..utils.naming import MOVING_SUMMARY_OPS_KEY
 from .tower import get_current_tower_context
 from .symbolic_functions import rms
+from .common import get_global_step_var

 __all__ = ['create_scalar_summary', 'add_param_summary', 'add_activation_summary',
            'add_moving_summary']
@@ -98,13 +99,21 @@ def add_param_summary(*summary_lists):
         perform(p, act)

-def add_moving_summary(v, *args):
+def add_moving_summary(v, *args, **kwargs):
     """
     Args:
         v (tf.Tensor or list): tensor or list of tensors to summary. Must have
             scalar type.
         args: tensors to summary (support positional arguments)
+        decay (float): the decay rate. Defaults to 0.95.
+        collection (str): the name of the collection to add EMA-maintaining ops.
+            The default will work together with the default
+            :class:`MovingAverageSummary` callback.
     """
+    decay = kwargs.pop('decay', 0.95)
+    coll = kwargs.pop('collection', MOVING_SUMMARY_OPS_KEY)
+    assert len(kwargs) == 0, "Unknown arguments: " + str(kwargs)
     ctx = get_current_tower_context()
     if ctx is not None and not ctx.is_main_training_tower:
         return
@@ -112,5 +121,15 @@ def add_moving_summary(v, *args):
         v = [v]
     v.extend(args)
     for x in v:
+        assert isinstance(x, tf.Tensor), x
         assert x.get_shape().ndims == 0, x.get_shape()
-        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x)
+    # TODO will produce tower0/xxx?
+    with tf.name_scope(None):
+        averager = tf.train.ExponentialMovingAverage(
+            decay, num_updates=get_global_step_var(), name='EMA')
+        avg_maintain_op = averager.apply(v)
+        for c in v:
+            name = re.sub('tower[p0-9]+/', '', c.op.name)
+            tf.summary.scalar(name + '-summary', averager.average(c))
+    tf.add_to_collection(coll, avg_maintain_op)
@@ -147,6 +147,7 @@ def get_checkpoint_path(model_path):
     if os.path.basename(model_path) == model_path:
         model_path = os.path.join('.', model_path)  # avoid #4921 and #6142
     if os.path.basename(model_path) == 'checkpoint':
+        assert os.path.isfile(model_path), model_path
         model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
         # to be consistent with either v1 or v2
...
@@ -138,6 +138,8 @@
         # create an empty StatHolder
         self.stat_holder = StatHolder(logger.LOG_DIR)

+        self.config.session_init._setup_graph()
+
         def after_init(_, __):
             logger.info("Graph variables initialized.")
         scaffold = tf.train.Scaffold(
@@ -149,7 +151,7 @@
                 scaffold=scaffold, config=self.config.session_config),
             hooks=None)
         self.sess = self.monitored_sess._tf_sess()
-        self.config.session_init.init(self.sess)
+        self.config.session_init._run_init(self.sess)

     @abstractmethod
     def _setup(self):
...
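The trainer change above splits session initialization into two phases. Roughly, the sequence is now (a simplified restatement of the diff, not verbatim code):

    # 1. before the session exists: let the SessionInit add whatever ops it needs
    config.session_init._setup_graph()        # e.g. SaverRestore builds its tf.train.Saver here
    # 2. create the MonitoredSession / Scaffold (which may finalize the graph)
    # 3. after the session exists: run the actual initialization
    config.session_init._run_init(sess)       # e.g. saver.restore(sess, path)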
@@ -29,18 +29,19 @@ def mkdir_p(dirname):
             raise e

-def download(url, dir):
+def download(url, dir, filename=None):
     """
     Download URL to a directory. Will figure out the filename automatically
     from URL.
     """
     mkdir_p(dir)
-    fname = url.split('/')[-1]
-    fpath = os.path.join(dir, fname)
+    if filename is None:
+        filename = url.split('/')[-1]
+    fpath = os.path.join(dir, filename)

     def _progress(count, block_size, total_size):
         sys.stdout.write('\r>> Downloading %s %.1f%%' %
-                         (fname,
+                         (filename,
                           min(float(count * block_size) / total_size,
                               1.0) * 100.0))
         sys.stdout.flush()
@@ -54,7 +55,7 @@ def download(url, dir):
     assert size > 0, "Download an empty file!"
     sys.stdout.write('\n')
     # TODO human-readable size
-    print('Succesfully downloaded ' + fname + " " + str(size) + ' bytes.')
+    print('Succesfully downloaded ' + filename + " " + str(size) + ' bytes.')
     return fpath
...
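Example of the new optional argument; the module path, URLs, and local paths below are assumptions for illustration:

    from tensorpack.utils.fs import download

    # filename inferred from the URL, as before:
    download('http://example.com/files/data.tar.gz', '/tmp/data')
    # explicit filename, useful when the URL does not end in a sensible name:
    download('http://example.com/api/latest-dataset', '/tmp/data', filename='dataset.tar.gz')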
@@ -17,12 +17,12 @@ LOCAL_STEP_VAR_NAME = 'local_step:0'
 PREDICT_TOWER = 'towerp'

 # extra variables to summarize during training in a moving-average way
-MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
+MOVING_SUMMARY_OPS_KEY = 'MOVING_SUMMARY_OPS'

 # metainfo for input tensors
 INPUTS_KEY = 'INPUTS_METAINFO'

-SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
+SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]

 # export all upper case variables
 all_local_names = locals().keys()
...