Commit 8059ee40 authored by Yuxin Wu's avatar Yuxin Wu

Check ModelDescBase instead of ModelDesc in PredictConfig (fix #361)

parent 930af0b6
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import tensorflow as tf import tensorflow as tf
import six import six
from ..graph_builder import ModelDesc from ..graph_builder import ModelDescBase
from ..tfutils import get_default_sess_config from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession from ..tfutils.sessinit import SessionInit, JustCurrentSession
from ..tfutils.sesscreate import NewSessionCreator from ..tfutils.sesscreate import NewSessionCreator
...@@ -24,7 +24,7 @@ class PredictConfig(object): ...@@ -24,7 +24,7 @@ class PredictConfig(object):
): ):
""" """
Args: Args:
model (ModelDesc): the model to use. model (ModelDescBase): the model to use.
session_creator (tf.train.SessionCreator): how to create the session_creator (tf.train.SessionCreator): how to create the
session. Defaults to :class:`sesscreate.NewSessionCreator()`. session. Defaults to :class:`sesscreate.NewSessionCreator()`.
session_init (SessionInit): how to initialize variables of the session. session_init (SessionInit): how to initialize variables of the session.
...@@ -40,7 +40,7 @@ class PredictConfig(object): ...@@ -40,7 +40,7 @@ class PredictConfig(object):
def assert_type(v, tp): def assert_type(v, tp):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), v.__class__
self.model = model self.model = model
assert_type(self.model, ModelDesc) assert_type(self.model, ModelDescBase)
if session_init is None: if session_init is None:
session_init = JustCurrentSession() session_init = JustCurrentSession()
......
...@@ -191,11 +191,8 @@ def dump_chkpt_vars(model_path): ...@@ -191,11 +191,8 @@ def dump_chkpt_vars(model_path):
def is_training_name(name): def is_training_name(name):
""" """
Guess if a name belongs to a training-only variables. **Guess** if this variable is only used in training.
Only used internally to avoid too many logging. Do not use it. Only used internally to avoid too many logging. Do not use it.
Returns:
bool: Guess whether this tensor is something only used in training.
""" """
# TODO: maybe simply check against TRAINABLE_VARIABLES and MODEL_VARIABLES? # TODO: maybe simply check against TRAINABLE_VARIABLES and MODEL_VARIABLES?
# TODO or use get_slot_names() # TODO or use get_slot_names()
...@@ -210,6 +207,6 @@ def is_training_name(name): ...@@ -210,6 +207,6 @@ def is_training_name(name):
return True return True
if name.endswith('/Adagrad'): if name.endswith('/Adagrad'):
return True return True
if name.startswith('/EMA'): if name.startswith('EMA/'): # all the moving average summaries
return True return True
return False return False
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment