Commit 20d1af11 authored by Yuxin Wu

Simplify ModelSaver: just save variables without renaming them.

parent 3e97f126
......@@ -8,7 +8,6 @@ import shutil
from .base import Triggerable
from ..utils import logger
from ..tfutils.varmanip import get_savename_from_varname
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
......@@ -43,27 +42,12 @@ class ModelSaver(Triggerable):
vars.extend(tf.get_collection(key))
self.path = os.path.join(self.checkpoint_dir, 'model')
self.saver = tf.train.Saver(
var_list=ModelSaver._get_var_dict(vars),
var_list=vars,
max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq,
write_version=tf.train.SaverDef.V2)
self.meta_graph_written = False
@staticmethod
def _get_var_dict(vars):
var_dict = {}
for v in vars:
name = get_savename_from_varname(v.name)
if name not in var_dict:
if name != v.name:
logger.info(
"[ModelSaver] {} renamed to {} when saving model.".format(v.name, name))
var_dict[name] = v
else:
logger.info("[ModelSaver] Variable {} won't be saved \
due to an alternative in a different tower".format(v.name, var_dict[name].name))
return var_dict
def _trigger(self):
try:
if not self.meta_graph_written:
......
......@@ -143,7 +143,7 @@ class ParamRestore(SessionInit):
def _init(self, sess):
variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) # TODO
variable_names = set([get_savename_from_varname(k.name) for k in variables])
variable_names = set([k.name for k in variables])
param_names = set(six.iterkeys(self.prms))
intersect = variable_names & param_names
......@@ -156,9 +156,7 @@ class ParamRestore(SessionInit):
for k in sorted(param_names - variable_names):
logger.warn("Variable {} in the dict not found in the graph!".format(k))
upd = SessionUpdate(sess,
[v for v in variables if
get_savename_from_varname(v.name) in intersect])
upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
logger.info("Restoring from dict ...")
upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})
......
......@@ -7,7 +7,6 @@ import six
import os
import tensorflow as tf
from collections import defaultdict
import re
import numpy as np
from ..utils import logger
from ..utils.naming import PREDICT_TOWER
......@@ -34,8 +33,6 @@ def get_savename_from_varname(
logger.error("No variable under '{}' name scope should be saved!".format(PREDICT_TOWER))
# don't overwrite anything in the current prediction graph
return None
if 'tower' in name:
name = re.sub('tower[p0-9]+/', '', name)
if varname_prefix is not None \
and name.startswith(varname_prefix):
name = name[len(varname_prefix) + 1:]
......@@ -56,8 +53,7 @@ class SessionUpdate(object):
self.sess = sess
self.name_map = defaultdict(list)
for v in vars_to_update:
savename = get_savename_from_varname(v.name)
self.name_map[savename].append(v)
self.name_map[v.name].append(v)
@staticmethod
def load_value_to_var(var, val, strict=False):
......@@ -133,11 +129,7 @@ def dump_session_params(path):
assert len(set(var)) == len(var), "TRAINABLE and MODEL variables have duplication!"
result = {}
for v in var:
name = get_savename_from_varname(v.name)
if name in result:
logger.info("Variable {} would be stored instead of another with \
the same name".format(v.name))
result[name] = v.eval()
result[v.name] = v.eval()
logger.info("Variables to save to {}:".format(path))
logger.info(str(result.keys()))
np.save(path, result)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment