Commit cc89b105 authored by Yuxin Wu

EMA callback doesn't create variables itself. Add old SaverRestore to be fast

parent 20d1af11
......@@ -68,6 +68,10 @@ def get_config():
class WGANTrainer(FeedfreeTrainerBase):
""" A new trainer which runs two optimization ops with 5:1 ratio.
This is to be consistent with the original code, but I found just
running them 1:1 (i.e. using the existing GANTrainer) also works well.
"""
def __init__(self, config):
self._input_method = QueueInput(config.dataflow)
super(WGANTrainer, self).__init__(config)
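For context, a minimal sketch of what such a 5:1 schedule could look like (illustrative only; the attribute names self.d_min and self.g_min for the two optimization ops are assumptions, since the trainer body is not shown in this hunk):

    def run_step(self):
        # 5 critic (discriminator) updates ...
        for _ in range(5):
            self.sess.run(self.d_min)
        # ... followed by 1 generator update
        self.sess.run(self.g_min)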
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: ls-checkpoint.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
import sys
import pprint
from tensorpack.tfutils.varmanip import get_checkpoint_path
path = get_checkpoint_path(sys.argv[1])
reader = tf.train.NewCheckpointReader(path)
pprint.pprint(reader.get_variable_to_shape_map())
......@@ -4,10 +4,8 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
import re
from ..utils.naming import MOVING_SUMMARY_VARS_KEY
from ..tfutils.common import get_global_step_var
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .base import Callback
__all__ = ['MovingAverageSummary']
......@@ -17,28 +15,18 @@ class MovingAverageSummary(Callback):
""" Maintain the moving average of the tensors
in every step, and summarize them. Enabled by default.
"""
def __init__(self, collection=MOVING_SUMMARY_VARS_KEY, decay=0.95):
def __init__(self, collection=MOVING_SUMMARY_OPS_KEY):
"""
Args:
collection(str): the collection of tensors to summarize. The
default would work with :func:`add_moving_summary`.
decay(float): the decay of the moving average.
collection(str): the collection of EMA-maintaining ops.
The default would work with :func:`add_moving_summary()`,
but a different collection can be used as well.
"""
self._collection = collection
self._decay = decay
def _setup_graph(self):
tensors = set(tf.get_collection(self._collection))
# TODO will produce tower0/xxx. not elegant
with tf.name_scope(None):
averager = tf.train.ExponentialMovingAverage(
self._decay, num_updates=get_global_step_var(), name='EMA')
avg_maintain_op = averager.apply(tensors)
for idx, c in enumerate(tensors):
name = re.sub('tower[p0-9]+/', '', c.op.name)
tf.summary.scalar(name + '-summary', averager.average(c))
self.ema_op = avg_maintain_op
ops = tf.get_collection(self._collection)
self.ema_op = tf.group(*ops, name='summary_moving_averages')
def _extra_fetches(self):
return [self.ema_op]
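The division of labor after this change: add_moving_summary builds the EMA-maintaining ops and registers them in a collection, while this callback only groups and fetches them each step. A minimal sketch of that contract (the literal key stands in for MOVING_SUMMARY_OPS_KEY):

    import tensorflow as tf

    COLL = 'MOVING_SUMMARY_OPS'  # stands in for MOVING_SUMMARY_OPS_KEY

    # producer side (add_moving_summary): create the EMA op and register it
    x = tf.Variable(0.0, name='cost')
    averager = tf.train.ExponentialMovingAverage(0.95, name='EMA')
    tf.add_to_collection(COLL, averager.apply([x]))

    # consumer side (this callback): group whatever was registered
    ema_op = tf.group(*tf.get_collection(COLL), name='summary_moving_averages')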
......@@ -3,22 +3,20 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os
from abc import abstractmethod, ABCMeta
import numpy as np
import tensorflow as tf
import six
from ..utils import logger, PREDICT_TOWER
from ..utils import logger
from .common import get_op_tensor_name
from .varmanip import (SessionUpdate, get_savename_from_varname,
is_training_name, get_checkpoint_path)
__all__ = ['SessionInit', 'NewSession', 'SaverRestore',
__all__ = ['SessionInit', 'NewSession', 'SaverRestore', 'SaverRestoreRelaxed',
'ParamRestore', 'ChainInit',
'JustCurrentSession', 'get_model_loader']
@six.add_metaclass(ABCMeta)
class SessionInit(object):
""" Base class for utilities to initialize a session. """
def init(self, sess):
......@@ -30,23 +28,31 @@ class SessionInit(object):
"""
self._init(sess)
@abstractmethod
def _init(self, sess):
self._setup_graph()
self._run_init(sess)
def _setup_graph(self):
pass
def _run_init(self, sess):
pass
class JustCurrentSession(SessionInit):
""" This is a no-op placeholder"""
def _init(self, sess):
pass
pass
class NewSession(SessionInit):
"""
Initialize global variables by their initializer.
"""
def _init(self, sess):
sess.run(tf.global_variables_initializer())
def _setup_graph(self):
self.op = tf.global_variables_initializer()
def _run_init(self, sess):
sess.run(self.op)
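The interface change splits initialization into two phases: _setup_graph() may still add ops (the graph is not yet finalized), while _run_init(sess) must only run them. A custom initializer written against the split might look like this (an illustrative sketch assuming the module context above, not a class from the library):

    class ZeroInit(SessionInit):
        """ Illustrative only: reset every global variable to zero. """
        def _setup_graph(self):
            # graph construction belongs here, before the graph is finalized
            self._op = tf.group(*[v.assign(tf.zeros_like(v))
                                  for v in tf.global_variables()])

        def _run_init(self, sess):
            # only session.run() from here on; no new ops may be created
            sess.run(self._op)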
class CheckpointReaderAdapter(object):
......@@ -58,7 +64,7 @@ class CheckpointReaderAdapter(object):
self._reader = reader
m = self._reader.get_variable_to_shape_map()
self._map = {k if k.endswith(':0') else k + ':0': v
for k, v in m.iteritems()}
for k, v in six.iteritems(m)}
def get_variable_to_shape_map(self):
return self._map
......@@ -74,23 +80,77 @@ class CheckpointReaderAdapter(object):
def has_tensor(self, name):
return name in self._map
# some checkpoint might not have ':0'
def get_real_name(self, name):
if self._reader.has_tensor(name):
return name
assert self.has_tensor(name)
return name[:-2]
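A toy illustration of the name normalization the adapter performs (standalone, no checkpoint needed):

    m = {'conv0/W': [3, 3, 3, 64]}               # raw checkpoint map, no ':0'
    normalized = {k if k.endswith(':0') else k + ':0': v
                  for k, v in m.items()}
    assert 'conv0/W:0' in normalized              # lookups use tensor names
    # get_real_name('conv0/W:0') would return 'conv0/W', the stored name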
class SaverRestore(SessionInit):
"""
Restore a tensorflow checkpoint saved by :class:`tf.train.Saver` or :class:`ModelSaver`.
"""
def __init__(self, model_path, prefix=None):
"""
Args:
model_path (str): path to the model (model-xxxx) or a ``checkpoint`` file.
model_path (str): a model name (model-xxxx) or a ``checkpoint`` file.
prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint
"""
model_path = get_checkpoint_path(model_path)
self.path = model_path
self.prefix = prefix
def _init(self, sess):
def _setup_graph(self):
dic = self._get_restore_dict()
self.saver = tf.train.Saver(var_list=dic, name=str(id(dic)))
def _run_init(self, sess):
logger.info("Restoring checkpoint from {} ...".format(self.path))
self.saver.restore(sess, self.path)
@staticmethod
def _read_checkpoint_vars(model_path):
""" return a set of strings """
reader = tf.train.NewCheckpointReader(model_path)
reader = CheckpointReaderAdapter(reader) # use an adapter to standardize the name
ckpt_vars = reader.get_variable_to_shape_map().keys()
return reader, set(ckpt_vars)
def _get_restore_dict(self):
reader, chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
graph_vars = tf.global_variables()
var_dict = {}
chkpt_vars_used = set()
for v in graph_vars:
name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
if reader.has_tensor(name):
ckpt_name = reader.get_real_name(name)
assert ckpt_name not in var_dict, "Restore conflict: {} and {}".format(v.name, var_dict[ckpt_name].name)
var_dict[ckpt_name] = v
chkpt_vars_used.add(name)
else:
vname = v.op.name
if not is_training_name(vname):
logger.warn("Variable {} in the graph not found in checkpoint!".format(vname))
if len(chkpt_vars_used) < len(chkpt_vars):
unused = chkpt_vars - chkpt_vars_used
for name in sorted(unused):
if not is_training_name(name):
logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
return var_dict
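User-facing usage is unchanged; only the graph work moved into _setup_graph. For example (the checkpoint path is a placeholder):

    from tensorpack.tfutils.sessinit import SaverRestore

    init = SaverRestore('train_log/model-10000')  # placeholder path
    # the trainer calls init._setup_graph() while the graph is still mutable,
    # then init._run_init(sess) once the MonitoredSession exists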
class SaverRestoreRelaxed(SaverRestore):
""" Same as :class:`SaverRestore`, but has more relaxed constraints.
It allows upcasting certain variables, or reshaping certain
variables when there is a mismatch that can be fixed.
Another advantage is that it doesn't add any new ops to the graph.
But it is also slower than :class:`SaverRestore`.
"""
def _run_init(self, sess):
logger.info(
"Restoring checkpoint from {} ...".format(self.path))
reader, chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
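The loop body is truncated in this hunk; conceptually it pushes each matching checkpoint value straight into the session rather than through a Saver, which is why no new ops appear in the graph. A hedged sketch of the idea (whether the adapter exposes get_tensor is an assumption here):

    # conceptual sketch only -- the real code goes through varmanip.SessionUpdate
    for v in tf.global_variables():
        name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
        if reader.has_tensor(name):
            value = reader.get_tensor(reader.get_real_name(name))  # assumed helper
            v.load(value, session=sess)     # direct assignment, adds no ops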
......@@ -114,18 +174,6 @@ class SaverRestore(SessionInit):
if not is_training_name(name):
logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
@staticmethod
def _read_checkpoint_vars(model_path):
""" return a set of strings """
reader = tf.train.NewCheckpointReader(model_path)
reader = CheckpointReaderAdapter(reader)
ckpt_vars = reader.get_variable_to_shape_map().keys()
for v in ckpt_vars:
if v.startswith(PREDICT_TOWER):
logger.error("Found {} in checkpoint. "
"But anything from prediction tower shouldn't be saved.".format(v.name))
return reader, set(ckpt_vars)
class ParamRestore(SessionInit):
"""
......@@ -140,7 +188,7 @@ class ParamRestore(SessionInit):
# use varname (with :0) for consistency
self.prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(param_dict)}
def _init(self, sess):
def _run_init(self, sess):
variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) # TODO
variable_names = set([k.name for k in variables])
......@@ -182,6 +230,14 @@ class ChainInit(SessionInit):
for i in self.inits:
i.init(sess)
def _setup_graph(self):
for i in self.inits:
i._setup_graph()
def _run_init(self, sess):
for i in self.inits:
i._run_init(sess)
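ChainInit forwards both phases to every child, so heterogeneous initializers still compose, e.g. (a sketch; the path and the numpy dict are placeholders):

    init = ChainInit([
        SaverRestore('train_log/model-10000'),   # fast native-Saver restore
        ParamRestore({'fc/W': np_weights}),      # then override a few params
    ])
    init._setup_graph()   # builds the Saver; ParamRestore needs no graph work
    init._run_init(sess)  # runs both restores in order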
def get_model_loader(filename):
"""
......
......@@ -7,9 +7,10 @@ import tensorflow as tf
import re
from ..utils import log_deprecated
from ..utils.naming import MOVING_SUMMARY_VARS_KEY
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .tower import get_current_tower_context
from .symbolic_functions import rms
from .common import get_global_step_var
__all__ = ['create_scalar_summary', 'add_param_summary', 'add_activation_summary',
'add_moving_summary']
......@@ -98,13 +99,21 @@ def add_param_summary(*summary_lists):
perform(p, act)
def add_moving_summary(v, *args):
def add_moving_summary(v, *args, **kwargs):
"""
Args:
v (tf.Tensor or list): tensor or list of tensors to summarize. Must have
scalar type.
args: more tensors to summarize (positional arguments are supported)
decay (float): the decay rate. Defaults to 0.95.
collection (str): the name of the collection to add EMA-maintaining ops.
The default will work together with the default
:class:`MovingAverageSummary` callback.
"""
decay = kwargs.pop('decay', 0.95)
coll = kwargs.pop('collection', MOVING_SUMMARY_OPS_KEY)
assert len(kwargs) == 0, "Unknown arguments: " + str(kwargs)
ctx = get_current_tower_context()
if ctx is not None and not ctx.is_main_training_tower:
return
......@@ -112,5 +121,15 @@ def add_moving_summary(v, *args):
v = [v]
v.extend(args)
for x in v:
assert isinstance(x, tf.Tensor), x
assert x.get_shape().ndims == 0, x.get_shape()
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x)
# TODO will produce tower0/xxx?
with tf.name_scope(None):
averager = tf.train.ExponentialMovingAverage(
decay, num_updates=get_global_step_var(), name='EMA')
avg_maintain_op = averager.apply(v)
for c in v:
name = re.sub('tower[p0-9]+/', '', c.op.name)
tf.summary.scalar(name + '-summary', averager.average(c))
tf.add_to_collection(coll, avg_maintain_op)
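Callers can now tune both knobs, e.g. (cost and wd_cost stand for any scalar tensors in the graph; the custom collection name is made up):

    add_moving_summary(cost, wd_cost)             # defaults: decay=0.95
    add_moving_summary(cost, decay=0.99,
                       collection='MY_EMA_OPS')   # a custom collection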
......@@ -147,6 +147,7 @@ def get_checkpoint_path(model_path):
if os.path.basename(model_path) == model_path:
model_path = os.path.join('.', model_path) # avoid #4921 and #6142
if os.path.basename(model_path) == 'checkpoint':
assert os.path.isfile(model_path), model_path
model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
# to be consistent with either v1 or v2
......
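With this check, both a concrete model name and a checkpoint file resolve to a usable path, e.g. (paths are placeholders):

    from tensorpack.tfutils.varmanip import get_checkpoint_path

    get_checkpoint_path('train_log/model-10000')  # -> the model itself
    get_checkpoint_path('train_log/checkpoint')   # -> latest model it points to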
......@@ -138,6 +138,8 @@ class Trainer(object):
# create an empty StatHolder
self.stat_holder = StatHolder(logger.LOG_DIR)
self.config.session_init._setup_graph()
def after_init(_, __):
logger.info("Graph variables initialized.")
scaffold = tf.train.Scaffold(
......@@ -149,7 +151,7 @@ class Trainer(object):
scaffold=scaffold, config=self.config.session_config),
hooks=None)
self.sess = self.monitored_sess._tf_sess()
self.config.session_init.init(self.sess)
self.config.session_init._run_init(self.sess)
@abstractmethod
def _setup(self):
......
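The two calls must straddle session creation because tf.train.MonitoredSession finalizes the graph when it is built; a Saver created any later would fail. Schematically (the helper name is a placeholder for the Scaffold setup above):

    session_init._setup_graph()          # 1) graph still mutable: create Saver etc.
    sess = create_monitored_session()    # 2) graph gets finalized here
    session_init._run_init(sess)         # 3) from now on, only sess.run()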
......@@ -29,18 +29,19 @@ def mkdir_p(dirname):
raise e
def download(url, dir):
def download(url, dir, filename=None):
"""
Download URL to a directory. Will figure out the filename automatically
from URL, unless ``filename`` is given.
"""
mkdir_p(dir)
fname = url.split('/')[-1]
fpath = os.path.join(dir, fname)
if filename is None:
filename = url.split('/')[-1]
fpath = os.path.join(dir, filename)
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' %
(fname,
(filename,
min(float(count * block_size) / total_size,
1.0) * 100.0))
sys.stdout.flush()
......@@ -54,7 +55,7 @@ def download(url, dir):
assert size > 0, "Download an empty file!"
sys.stdout.write('\n')
# TODO human-readable size
print('Successfully downloaded ' + fname + " " + str(size) + ' bytes.')
print('Successfully downloaded ' + filename + " " + str(size) + ' bytes.')
return fpath
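Usage of the new parameter (URL and names are placeholders):

    fpath = download('http://example.com/data/file.tgz', '/tmp/data')
    # or override the name inferred from the URL:
    fpath = download('http://example.com/data/file.tgz', '/tmp/data',
                     filename='renamed.tgz')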
......
......@@ -17,12 +17,12 @@ LOCAL_STEP_VAR_NAME = 'local_step:0'
PREDICT_TOWER = 'towerp'
# extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
MOVING_SUMMARY_OPS_KEY = 'MOVING_SUMMARY_OPS'
# metainfo for input tensors
INPUTS_KEY = 'INPUTS_METAINFO'
SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]
# export all upper case variables
all_local_names = locals().keys()
......