Commit 091568ec authored by Yuxin Wu

fix DistributedTrainer (fix #505)

parent 2be64ce0
@@ -3,7 +3,6 @@
 # File: predict.py

 import tensorflow as tf
-from contextlib import contextmanager

 from ..utils import logger
 from ..tfutils.tower import TowerContext
@@ -29,14 +28,6 @@ class SimplePredictBuilder(GraphBuilder):
         device = '/gpu:{}'.format(device) if device >= 0 else '/cpu:0'
         self._device = device

-    @contextmanager
-    def _maybe_open_vs(self):
-        if len(self._vs_name):
-            with tf.variable_scope(self._vs_name):
-                yield
-        else:
-            yield
-
     def build(self, input, tower_fn):
         """
         Args:
@@ -51,7 +42,6 @@ class SimplePredictBuilder(GraphBuilder):
             self._ns_name, self._device))

         with tf.device(self._device), \
-                self._maybe_open_vs(), \
                 TowerContext(
                     self._ns_name, is_training=False, vs_name=self._vs_name):
             inputs = input.get_input_tensors()
...
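The predict.py hunks drop the builder's own optional variable-scope wrapper; `vs_name` is still passed to `TowerContext`, which presumably opens that scope itself, so wrapping the tower body in `_maybe_open_vs()` as well would open the scope twice. For reference, a minimal standalone sketch of the pattern the removed helper implemented (plain TF1 code, illustrative only, not tensorpack internals):

```python
import tensorflow as tf
from contextlib import contextmanager

@contextmanager
def maybe_open_vs(vs_name):
    # Open a variable scope only when a non-empty name is given,
    # mirroring the removed SimplePredictBuilder._maybe_open_vs helper.
    if vs_name:
        with tf.variable_scope(vs_name):
            yield
    else:
        yield

with maybe_open_vs('towerp'):
    w = tf.get_variable('w', shape=[1])
    print(w.name)  # towerp/w:0
```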
@@ -34,7 +34,9 @@ __all__ = ['PlaceholderInput', 'FeedInput',
 def _get_reset_callback(df):
-    return CallbackFactory(setup_graph=lambda _: df.reset_state())
+    ret = CallbackFactory(setup_graph=lambda _: df.reset_state())
+    ret.chief_only = False
+    return ret


 class PlaceholderInput(InputSource):
...
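Marking the dataflow-reset callback as not chief-only means it runs on every worker, which matters for distributed training where each worker presumably needs to call `df.reset_state()` on its own input pipeline. A simplified sketch of how such a flag is typically honored (hypothetical names, not tensorpack's callback machinery):

```python
class ResetCallback:
    # Hypothetical stand-in for the object returned by CallbackFactory.
    chief_only = True  # default: only the chief worker runs the callback

    def __init__(self, df):
        self._df = df

    def setup_graph(self, trainer):
        self._df.reset_state()

def run_setup_callbacks(callbacks, trainer, is_chief):
    for cb in callbacks:
        if getattr(cb, 'chief_only', True) and not is_chief:
            continue  # with chief_only = False, the reset also runs on non-chief workers
        cb.setup_graph(trainer)
```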
@@ -124,8 +124,9 @@ class TowerContext(object):
         global _CurrentTowerContext
         assert _CurrentTowerContext is None, "Cannot nest TowerContext!"
         _CurrentTowerContext = self
-        curr_vs = tf.get_variable_scope()
-        assert curr_vs.name == '', "Cannot nest TowerContext with an existing variable scope!"
+        if self.is_training:
+            curr_vs = tf.get_variable_scope()
+            assert curr_vs.name == '', "In training, cannot nest TowerContext with an existing variable scope!"

         self._ctxs = self._get_scopes()
         self._ctxs.append(self._collection_guard)
...
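The scope assertion in `TowerContext` is now applied only when `is_training` is set, so a prediction tower (such as the one `SimplePredictBuilder` opens above) may be entered while a variable scope is already active. A small sketch of the relaxed guard in isolation (TF1 scoping assumed, derived from the hunk above):

```python
import tensorflow as tf

def check_tower_scope(is_training):
    # Mirrors the new guard: only a training tower must be created at the
    # root variable scope; prediction towers may be nested in an existing scope.
    if is_training:
        curr_vs = tf.get_variable_scope()
        assert curr_vs.name == '', \
            "In training, cannot nest TowerContext with an existing variable scope!"

with tf.variable_scope('tower0'):
    check_tower_scope(is_training=False)   # allowed after this change
    # check_tower_scope(is_training=True)  # would raise AssertionError here
```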
@@ -9,7 +9,7 @@ from ..input_source import (
 from .config import TrainConfig
 from .tower import SingleCostTrainer
-from .trainers import SimpleTrainer, DistributedTrainerReplicated
+from .trainers import SimpleTrainer

 __all__ = ['launch_train_with_config', 'apply_default_prefetch']

@@ -77,12 +77,6 @@ def launch_train_with_config(config, trainer):
     input = config.data or config.dataflow
     input = apply_default_prefetch(input, trainer, config.tower)

-    if isinstance(trainer, DistributedTrainerReplicated) and \
-            config.session_config is not None:
-        raise ValueError(
-            "Cannot set session_config for distributed training! "
-            "To use a custom session config, pass it to tf.train.Server.")
-
     trainer.setup_graph(
         inputs_desc, input,
         model._build_graph_get_cost, model.get_optimizer)
...
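`launch_train_with_config` no longer imports or special-cases `DistributedTrainerReplicated`; the equivalent validation lives in the trainer's own `initialize` (next hunk), whose error message points users at `tf.train.Server` instead. A hedged example of that recommended alternative, with placeholder host addresses:

```python
import tensorflow as tf

# A custom session config is given to tf.train.Server rather than to the trainer.
cluster = tf.train.ClusterSpec({
    'ps': ['host0:2222'],                    # placeholder addresses
    'worker': ['host1:2222', 'host2:2222'],
})
config = tf.ConfigProto(allow_soft_placement=True)
server = tf.train.Server(cluster, job_name='worker', task_index=0, config=config)
```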
@@ -165,7 +165,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
             logger.info("Running ps {}".format(self.server.server_def.task_index))
             logger.info("Kill me with 'kill {}'".format(os.getpid()))
             self.server.join()  # this function will never return tensorflow#4713
-            raise RuntimeError("This is a bug in tensorpack. Server.join() for ps should never return!")
+            raise RuntimeError("This is a bug. Server.join() for ps should never return!")

         with override_to_local_variable():
             get_global_step_var()  # gs should be local

@@ -204,7 +204,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
                 "You are not allowed to set session_creator or session_config for distributed training! "
                 "To use a custom session config, pass it to tf.train.Server.")
         super(DistributedTrainerReplicated, self).initialize(
-            get_distributed_session_creator(), session_init)
+            get_distributed_session_creator(self.server), session_init)

     @property
     def _main_tower_vs_name(self):
...
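The last hunk passes the `tf.train.Server` into `get_distributed_session_creator`, which presumably needs it to build a session that targets the cluster (via `server.target`) and to tell chief from non-chief workers. A hedged sketch of what such a session creator typically does, not tensorpack's implementation:

```python
import tensorflow as tf

def make_distributed_session_creator(server):
    # Illustrative only: pick a chief or worker session creator and point it at
    # the server's gRPC target so the session runs against the cluster.
    sd = server.server_def
    is_chief = (sd.job_name == 'worker' and sd.task_index == 0)
    if is_chief:
        return tf.train.ChiefSessionCreator(master=server.target)
    return tf.train.WorkerSessionCreator(master=server.target)
```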