Commit 091568ec authored by Yuxin Wu

fix DistributedTrainer (fix #505)

parent 2be64ce0
tensorpack/graph_builder/predict.py
@@ -3,7 +3,6 @@
 # File: predict.py

 import tensorflow as tf
-from contextlib import contextmanager

 from ..utils import logger
 from ..tfutils.tower import TowerContext
@@ -29,14 +28,6 @@ class SimplePredictBuilder(GraphBuilder):
         device = '/gpu:{}'.format(device) if device >= 0 else '/cpu:0'
         self._device = device

-    @contextmanager
-    def _maybe_open_vs(self):
-        if len(self._vs_name):
-            with tf.variable_scope(self._vs_name):
-                yield
-        else:
-            yield
-
     def build(self, input, tower_fn):
         """
         Args:
@@ -51,7 +42,6 @@ class SimplePredictBuilder(GraphBuilder):
             self._ns_name, self._device))

         with tf.device(self._device), \
-                self._maybe_open_vs(), \
                 TowerContext(
                     self._ns_name, is_training=False, vs_name=self._vs_name):
             inputs = input.get_input_tensors()
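Taken together, these three hunks delete a redundant layer: the helper conditionally opened self._vs_name as a variable scope right before entering a TowerContext that already receives the same vs_name and opens the scope itself. A condensed, standalone sketch of the surviving pattern (names device, ns_name, vs_name stand in for the builder's attributes; the relaxed assertion in tower.py below is what makes this nesting legal for inference):

    import tensorflow as tf
    from tensorpack.tfutils.tower import TowerContext

    def build_predict_tower(input_source, tower_fn, device, ns_name, vs_name):
        # TowerContext opens vs_name itself, so the extra outer
        # tf.variable_scope (the removed _maybe_open_vs) is unnecessary.
        with tf.device(device), \
                TowerContext(ns_name, is_training=False, vs_name=vs_name):
            inputs = input_source.get_input_tensors()
            return tower_fn(*inputs)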
tensorpack/input_source/input_source.py
@@ -34,7 +34,9 @@ __all__ = ['PlaceholderInput', 'FeedInput',


 def _get_reset_callback(df):
-    return CallbackFactory(setup_graph=lambda _: df.reset_state())
+    ret = CallbackFactory(setup_graph=lambda _: df.reset_state())
+    ret.chief_only = False
+    return ret


 class PlaceholderInput(InputSource):
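This is the heart of the distributed fix: tensorpack callbacks are chief-only by default, so non-chief workers never called df.reset_state() on their own DataFlow. Setting chief_only = False makes the reset run on every worker. A minimal sketch of the same idea against the public API (the wrapper name below is hypothetical):

    from tensorpack.callbacks import CallbackFactory

    def make_reset_callback(df):
        # Callbacks run only on the chief by default; each worker feeds
        # from its own DataFlow, so the reset must run on all of them.
        cb = CallbackFactory(setup_graph=lambda _: df.reset_state())
        cb.chief_only = False
        return cb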
tensorpack/tfutils/tower.py
@@ -124,8 +124,9 @@ class TowerContext(object):
         global _CurrentTowerContext
         assert _CurrentTowerContext is None, "Cannot nest TowerContext!"
         _CurrentTowerContext = self
-        curr_vs = tf.get_variable_scope()
-        assert curr_vs.name == '', "Cannot nest TowerContext with an existing variable scope!"
+        if self.is_training:
+            curr_vs = tf.get_variable_scope()
+            assert curr_vs.name == '', "In training, cannot nest TowerContext with an existing variable scope!"

         self._ctxs = self._get_scopes()
         self._ctxs.append(self._collection_guard)
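The empty-scope assertion is now enforced only for training towers, which is exactly what the predict-builder change above relies on: an inference TowerContext may be entered while a variable scope is already open. A toy illustration of the new rule, assuming TF 1.x and the class as patched above (the tower name is arbitrary):

    import tensorflow as tf
    from tensorpack.tfutils.tower import TowerContext

    with tf.variable_scope('pred'):                  # a pre-existing scope
        with TowerContext('pred', is_training=False):
            pass                                     # fine after this commit
        # A training tower would still trip the assertion:
        # with TowerContext('pred', is_training=True):
        #     ...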
tensorpack/train/interface.py
@@ -9,7 +9,7 @@ from ..input_source import (
 from .config import TrainConfig
 from .tower import SingleCostTrainer
-from .trainers import SimpleTrainer, DistributedTrainerReplicated
+from .trainers import SimpleTrainer

 __all__ = ['launch_train_with_config', 'apply_default_prefetch']
@@ -77,12 +77,6 @@ def launch_train_with_config(config, trainer):
     input = config.data or config.dataflow
     input = apply_default_prefetch(input, trainer, config.tower)

-    if isinstance(trainer, DistributedTrainerReplicated) and \
-            config.session_config is not None:
-        raise ValueError(
-            "Cannot set session_config for distributed training! "
-            "To use a custom session config, pass it to tf.train.Server.")
-
     trainer.setup_graph(
         inputs_desc, input,
         model._build_graph_get_cost, model.get_optimizer)
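The early check (and the trainer import it needed) moves out of launch_train_with_config; the equivalent guard already lives in DistributedTrainerReplicated.initialize(), shown in the last hunk below. On the user side, the supported way to customize the session in distributed mode is to configure the tf.train.Server, roughly like this (a user-side sketch; cluster addresses and task index are placeholders):

    import tensorflow as tf

    cluster = tf.train.ClusterSpec({
        'ps': ['host0:2222'],
        'worker': ['host1:2222', 'host2:2222'],
    })
    # The custom session config goes to the server, not to TrainConfig.
    server = tf.train.Server(cluster, job_name='worker', task_index=0,
                             config=tf.ConfigProto(allow_soft_placement=True))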
tensorpack/train/trainers.py
@@ -165,7 +165,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
             logger.info("Running ps {}".format(self.server.server_def.task_index))
             logger.info("Kill me with 'kill {}'".format(os.getpid()))
             self.server.join()  # this function will never return tensorflow#4713
-            raise RuntimeError("This is a bug in tensorpack. Server.join() for ps should never return!")
+            raise RuntimeError("This is a bug. Server.join() for ps should never return!")

         with override_to_local_variable():
             get_global_step_var()  # gs should be local
@@ -204,7 +204,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
                 "You are not allowed to set session_creator or session_config for distributed training! "
                 "To use a custom session config, pass it to tf.train.Server.")
         super(DistributedTrainerReplicated, self).initialize(
-            get_distributed_session_creator(), session_init)
+            get_distributed_session_creator(self.server), session_init)

     @property
     def _main_tower_vs_name(self):
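Passing self.server lets the session creator target this worker's own gRPC master rather than an in-process default. A rough sketch of the wiring (tensorpack's real get_distributed_session_creator does more, e.g. synchronization hooks; the use of ChiefSessionCreator here is an assumption for illustration only):

    import tensorflow as tf

    def get_distributed_session_creator(server):
        # The session must be created against this worker's gRPC target,
        # which only the tf.train.Server instance knows.
        return tf.train.ChiefSessionCreator(master=server.target)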