Commit 0addcdc6 authored by Yuxin Wu's avatar Yuxin Wu

Fix session creator bug for distributed trainer.

parent c99d013a
......@@ -86,13 +86,12 @@ class ModelDescBase(object):
:returns: a list of InputDesc
"""
# TODO only use InputSource in the future? Now only used in predictor_factory
def build_graph(self, inputs):
"""
Build the whole symbolic graph.
Args:
inputs (list[tf.Tensor] or InputSource): a list of tensors, or an :class:`InputSource`,
inputs (list[tf.Tensor]): a list of tensors,
that match the list of :class:`InputDesc` defined by ``_get_inputs``.
"""
if isinstance(inputs, InputSource):
......
......@@ -278,7 +278,6 @@ class Trainer(object):
Returns:
an :class:`OnlinePredictor`.
"""
# TODO move the logic to factory?
return self.predictor_factory.get_predictor(input_names, output_names, tower)
@property
......
......@@ -9,8 +9,7 @@ from ..callbacks import (
from ..dataflow.base import DataFlow
from ..graph_builder.model_desc import ModelDescBase
from ..utils import logger
from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit)
from ..tfutils import (JustCurrentSession, SessionInit)
from ..tfutils.sesscreate import NewSessionCreator
from ..input_source import InputSource
from ..utils.develop import log_deprecated
......@@ -98,7 +97,7 @@ class TrainConfig(object):
if session_config is not None:
self.session_creator = NewSessionCreator(config=session_config)
else:
self.session_creator = NewSessionCreator(config=get_default_sess_config())
self.session_creator = NewSessionCreator(config=None)
else:
self.session_creator = session_creator
assert session_config is None, "Cannot set both session_creator and session_config!"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment