Commit 8059ee40 authored by Yuxin Wu's avatar Yuxin Wu

Check ModelDescBase instead of ModelDesc in PredictConfig (fix #361)

parent 930af0b6
......@@ -5,7 +5,7 @@
import tensorflow as tf
import six
from ..graph_builder import ModelDesc
from ..graph_builder import ModelDescBase
from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession
from ..tfutils.sesscreate import NewSessionCreator
......@@ -24,7 +24,7 @@ class PredictConfig(object):
):
"""
Args:
model (ModelDesc): the model to use.
model (ModelDescBase): the model to use.
session_creator (tf.train.SessionCreator): how to create the
session. Defaults to :class:`sesscreate.NewSessionCreator()`.
session_init (SessionInit): how to initialize variables of the session.
......@@ -40,7 +40,7 @@ class PredictConfig(object):
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
self.model = model
assert_type(self.model, ModelDesc)
assert_type(self.model, ModelDescBase)
if session_init is None:
session_init = JustCurrentSession()
......
......@@ -191,11 +191,8 @@ def dump_chkpt_vars(model_path):
def is_training_name(name):
"""
Guess if a name belongs to a training-only variables.
**Guess** if this variable is only used in training.
Only used internally to avoid too many logging. Do not use it.
Returns:
bool: Guess whether this tensor is something only used in training.
"""
# TODO: maybe simply check against TRAINABLE_VARIABLES and MODEL_VARIABLES?
# TODO or use get_slot_names()
......@@ -210,6 +207,6 @@ def is_training_name(name):
return True
if name.endswith('/Adagrad'):
return True
if name.startswith('/EMA'):
if name.startswith('EMA/'): # all the moving average summaries
return True
return False
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment