Commit 20d1af11 authored by Yuxin Wu's avatar Yuxin Wu

simplify ModelSaver. just save without rename

parent 3e97f126
...@@ -8,7 +8,6 @@ import shutil ...@@ -8,7 +8,6 @@ import shutil
from .base import Triggerable from .base import Triggerable
from ..utils import logger from ..utils import logger
from ..tfutils.varmanip import get_savename_from_varname
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver'] __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
...@@ -43,27 +42,12 @@ class ModelSaver(Triggerable): ...@@ -43,27 +42,12 @@ class ModelSaver(Triggerable):
vars.extend(tf.get_collection(key)) vars.extend(tf.get_collection(key))
self.path = os.path.join(self.checkpoint_dir, 'model') self.path = os.path.join(self.checkpoint_dir, 'model')
self.saver = tf.train.Saver( self.saver = tf.train.Saver(
var_list=ModelSaver._get_var_dict(vars), var_list=vars,
max_to_keep=self.keep_recent, max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq, keep_checkpoint_every_n_hours=self.keep_freq,
write_version=tf.train.SaverDef.V2) write_version=tf.train.SaverDef.V2)
self.meta_graph_written = False self.meta_graph_written = False
@staticmethod
def _get_var_dict(vars):
var_dict = {}
for v in vars:
name = get_savename_from_varname(v.name)
if name not in var_dict:
if name != v.name:
logger.info(
"[ModelSaver] {} renamed to {} when saving model.".format(v.name, name))
var_dict[name] = v
else:
logger.info("[ModelSaver] Variable {} won't be saved \
due to an alternative in a different tower".format(v.name, var_dict[name].name))
return var_dict
def _trigger(self): def _trigger(self):
try: try:
if not self.meta_graph_written: if not self.meta_graph_written:
......
...@@ -143,7 +143,7 @@ class ParamRestore(SessionInit): ...@@ -143,7 +143,7 @@ class ParamRestore(SessionInit):
def _init(self, sess): def _init(self, sess):
variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) # TODO variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) # TODO
variable_names = set([get_savename_from_varname(k.name) for k in variables]) variable_names = set([k.name for k in variables])
param_names = set(six.iterkeys(self.prms)) param_names = set(six.iterkeys(self.prms))
intersect = variable_names & param_names intersect = variable_names & param_names
...@@ -156,9 +156,7 @@ class ParamRestore(SessionInit): ...@@ -156,9 +156,7 @@ class ParamRestore(SessionInit):
for k in sorted(param_names - variable_names): for k in sorted(param_names - variable_names):
logger.warn("Variable {} in the dict not found in the graph!".format(k)) logger.warn("Variable {} in the dict not found in the graph!".format(k))
upd = SessionUpdate(sess, upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
[v for v in variables if
get_savename_from_varname(v.name) in intersect])
logger.info("Restoring from dict ...") logger.info("Restoring from dict ...")
upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect}) upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})
......
...@@ -7,7 +7,6 @@ import six ...@@ -7,7 +7,6 @@ import six
import os import os
import tensorflow as tf import tensorflow as tf
from collections import defaultdict from collections import defaultdict
import re
import numpy as np import numpy as np
from ..utils import logger from ..utils import logger
from ..utils.naming import PREDICT_TOWER from ..utils.naming import PREDICT_TOWER
...@@ -34,8 +33,6 @@ def get_savename_from_varname( ...@@ -34,8 +33,6 @@ def get_savename_from_varname(
logger.error("No variable under '{}' name scope should be saved!".format(PREDICT_TOWER)) logger.error("No variable under '{}' name scope should be saved!".format(PREDICT_TOWER))
# don't overwrite anything in the current prediction graph # don't overwrite anything in the current prediction graph
return None return None
if 'tower' in name:
name = re.sub('tower[p0-9]+/', '', name)
if varname_prefix is not None \ if varname_prefix is not None \
and name.startswith(varname_prefix): and name.startswith(varname_prefix):
name = name[len(varname_prefix) + 1:] name = name[len(varname_prefix) + 1:]
...@@ -56,8 +53,7 @@ class SessionUpdate(object): ...@@ -56,8 +53,7 @@ class SessionUpdate(object):
self.sess = sess self.sess = sess
self.name_map = defaultdict(list) self.name_map = defaultdict(list)
for v in vars_to_update: for v in vars_to_update:
savename = get_savename_from_varname(v.name) self.name_map[v.name].append(v)
self.name_map[savename].append(v)
@staticmethod @staticmethod
def load_value_to_var(var, val, strict=False): def load_value_to_var(var, val, strict=False):
...@@ -133,11 +129,7 @@ def dump_session_params(path): ...@@ -133,11 +129,7 @@ def dump_session_params(path):
assert len(set(var)) == len(var), "TRAINABLE and MODEL variables have duplication!" assert len(set(var)) == len(var), "TRAINABLE and MODEL variables have duplication!"
result = {} result = {}
for v in var: for v in var:
name = get_savename_from_varname(v.name) result[v.name] = v.eval()
if name in result:
logger.info("Variable {} would be stored instead of another with \
the same name".format(v.name))
result[name] = v.eval()
logger.info("Variables to save to {}:".format(path)) logger.info("Variables to save to {}:".format(path))
logger.info(str(result.keys())) logger.info(str(result.keys()))
np.save(path, result) np.save(path, result)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment