Commit 0addcdc6 authored by Yuxin Wu's avatar Yuxin Wu

Fix session creator bug for distributed trainer.

parent c99d013a
...@@ -86,13 +86,12 @@ class ModelDescBase(object): ...@@ -86,13 +86,12 @@ class ModelDescBase(object):
:returns: a list of InputDesc :returns: a list of InputDesc
""" """
# TODO only use InputSource in the future? Now only used in predictor_factory
def build_graph(self, inputs): def build_graph(self, inputs):
""" """
Build the whole symbolic graph. Build the whole symbolic graph.
Args: Args:
inputs (list[tf.Tensor] or InputSource): a list of tensors, or an :class:`InputSource`, inputs (list[tf.Tensor]): a list of tensors,
that match the list of :class:`InputDesc` defined by ``_get_inputs``. that match the list of :class:`InputDesc` defined by ``_get_inputs``.
""" """
if isinstance(inputs, InputSource): if isinstance(inputs, InputSource):
......
...@@ -278,7 +278,6 @@ class Trainer(object): ...@@ -278,7 +278,6 @@ class Trainer(object):
Returns: Returns:
an :class:`OnlinePredictor`. an :class:`OnlinePredictor`.
""" """
# TODO move the logic to factory?
return self.predictor_factory.get_predictor(input_names, output_names, tower) return self.predictor_factory.get_predictor(input_names, output_names, tower)
@property @property
......
...@@ -9,8 +9,7 @@ from ..callbacks import ( ...@@ -9,8 +9,7 @@ from ..callbacks import (
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
from ..graph_builder.model_desc import ModelDescBase from ..graph_builder.model_desc import ModelDescBase
from ..utils import logger from ..utils import logger
from ..tfutils import (JustCurrentSession, from ..tfutils import (JustCurrentSession, SessionInit)
get_default_sess_config, SessionInit)
from ..tfutils.sesscreate import NewSessionCreator from ..tfutils.sesscreate import NewSessionCreator
from ..input_source import InputSource from ..input_source import InputSource
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
...@@ -98,7 +97,7 @@ class TrainConfig(object): ...@@ -98,7 +97,7 @@ class TrainConfig(object):
if session_config is not None: if session_config is not None:
self.session_creator = NewSessionCreator(config=session_config) self.session_creator = NewSessionCreator(config=session_config)
else: else:
self.session_creator = NewSessionCreator(config=get_default_sess_config()) self.session_creator = NewSessionCreator(config=None)
else: else:
self.session_creator = session_creator self.session_creator = session_creator
assert session_config is None, "Cannot set both session_creator and session_config!" assert session_config is None, "Cannot set both session_creator and session_config!"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment