Commit 12cd6e6c authored by Yuxin Wu

[Trainerv2] fix interface for distributed trainer.

parent a673974c
@@ -24,7 +24,13 @@ class NewSessionCreator(tf.train.SessionCreator):
         """
         self.target = target
         if config is None:
+            # distributed trainer doesn't support user-provided config
+            # we set this attribute so that they can check
+            self.user_provided_config = False
             config = get_default_sess_config()
+        else:
+            self.user_provided_config = True
         self.config = config
         self.graph = graph
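What this hunk does: NewSessionCreator now records whether the caller supplied its own session config, so code downstream (the distributed trainer below) can check the flag instead of inspecting TrainConfig. A small illustration of the new attribute (a sketch; it assumes the constructor keeps its target/graph/config keyword arguments and the usual tensorpack import path):

import tensorflow as tf
from tensorpack.tfutils.sesscreate import NewSessionCreator

sc = NewSessionCreator()
# no config given: a default ConfigProto is filled in and the flag stays False
assert sc.user_provided_config is False

sc = NewSessionCreator(config=tf.ConfigProto(allow_soft_placement=True))
# the caller supplied a config; distributed trainers will refuse this creator
assert sc.user_provided_config is True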
@@ -101,8 +101,6 @@ class TrainConfig(object):
         else:
             self.session_creator = session_creator
             assert session_config is None, "Cannot set both session_creator and session_config!"
-            # only used by DistributedTrainer for assertion!
-            self.session_config = session_config
 
         if steps_per_epoch is None:
             try:
@@ -85,7 +85,7 @@ class DistributedTrainerReplicated(Trainer):
     def _set_session_creator(self):
         old_sess_creator = self._config.session_creator
         if not isinstance(old_sess_creator, NewSessionCreator) \
-                or self._config.session_config is not None:
+                or old_sess_creator.user_provided_config:
             raise ValueError(
                 "Cannot set session_creator or session_config for distributed training! "
                 "To use a custom session config, pass it to tf.train.Server.")
@@ -231,8 +231,7 @@ class SingleCostTrainer(Trainer):
         These callbacks will be automatically added when you call `train()`.
         So you can usually ignore the return value.
         """
-        assert not input.setup_done()
-        input_callbacks = input.setup(inputs_desc)
+        input_callbacks = self._setup_input(inputs_desc, input)
         train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
         self._internal_callbacks = input_callbacks + train_callbacks
         return self._internal_callbacks
@@ -240,3 +239,7 @@ class SingleCostTrainer(Trainer):
     @abstractmethod
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
         pass
+
+    def _setup_input(self, inputs_desc, input):
+        assert not input.setup_done()
+        return input.setup(inputs_desc)
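These two hunks turn input setup into an overridable hook: the public setup path now calls `_setup_input` instead of calling `input.setup()` directly, and the base class keeps the old behavior as the default implementation. A rough sketch of the resulting structure (the name and signature of the enclosing public method are not visible in the hunk and are assumed here):

class SingleCostTrainer(Trainer):
    def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
        # delegate input setup to the hook so subclasses can customize it
        input_callbacks = self._setup_input(inputs_desc, input)
        train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
        self._internal_callbacks = input_callbacks + train_callbacks
        return self._internal_callbacks

    def _setup_input(self, inputs_desc, input):
        # default hook: identical behavior to before this commit
        assert not input.setup_done()
        return input.setup(inputs_desc)

DistributedTrainerReplicated overrides `_setup_input` in the trainers.py hunks below; that is what lets the ps job avoid building any input pipeline.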
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 # File: trainers.py
 import os
 
 from ..callbacks.graph import RunOp
@@ -17,7 +18,7 @@ from ..tfutils import get_global_step_var
 from ..tfutils.distributed import get_distributed_session_creator
 from ..input_source import QueueInput
-from .base import Trainer, SingleCostTrainer
+from .base import SingleCostTrainer
 
 __all__ = ['SimpleTrainer',
            'QueueInputTrainer',
@@ -32,8 +33,7 @@ class SimpleTrainer(SingleCostTrainer):
     Single-GPU single-cost single-tower trainer.
     """
 
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
-        self.train_op = SimpleBuilder().build(
-            input, get_cost_fn, get_opt_fn)
+        self.train_op = SimpleBuilder().build(input, get_cost_fn, get_opt_fn)
         return []
@@ -126,17 +126,13 @@ class DistributedTrainerReplicated(SingleCostTrainer):
             self.is_chief = False
         logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))
 
-    def train(self,
-              inputs_desc, input, get_cost_fn, get_opt_fn,
-              callbacks, monitors,
-              session_creator, session_init,
-              steps_per_epoch, starting_epoch, max_epoch):
+    def _setup_input(self, inputs_desc, input):
         if self.job_name == 'ps':
+            # ps shouldn't setup input either
             logger.info("Running ps {}".format(self.server.server_def.task_index))
             logger.info("Kill me with 'kill {}'".format(os.getpid()))
-            self.server.join()  # this will never return tensorflow#4713
-            return
+            self.server.join()  # this function will never return tensorflow#4713
+            raise RuntimeError("This is a bug in tensorpack. Server.join() for ps should never return!")
 
         with override_to_local_variable():
             get_global_step_var()  # gs should be local
@@ -144,14 +140,8 @@ class DistributedTrainerReplicated(SingleCostTrainer):
             # TODO This is not good because we don't know from here
             # whether something should be global or local. We now assume
             # they should be local.
-            input_callbacks = input.setup(inputs_desc)
-
-        train_callbacks = self.setup_graph(input, get_cost_fn, get_opt_fn)
-
-        Trainer.train(
-            self,
-            callbacks + input_callbacks + train_callbacks, monitors,
-            session_creator, session_init,
-            steps_per_epoch, starting_epoch, max_epoch)
+            assert not input.setup_done()
+            return input.setup(inputs_desc)
 
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
         self.train_op, initial_sync_op, model_sync_op = self._builder.build(
@@ -174,9 +164,10 @@ class DistributedTrainerReplicated(SingleCostTrainer):
         return callbacks
 
     def initialize(self, session_creator, session_init):
-        if not isinstance(session_creator, NewSessionCreator):
+        if not isinstance(session_creator, NewSessionCreator) or \
+                session_creator.user_provided_config:
             raise ValueError(
-                "Cannot set session_creator for distributed training! "
+                "Cannot set session_creator or session_config for distributed training! "
                 "To use a custom session config, pass it to tf.train.Server.")
         super(DistributedTrainerReplicated, self).initialize(
             get_distributed_session_creator(), session_init)
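For completeness, the error message points at the supported alternative: in distributed mode the per-process session config belongs to tf.train.Server, not to the trainer's session_creator or TrainConfig. A minimal sketch (host names, job name and task index are placeholders):

import tensorflow as tf

cluster = tf.train.ClusterSpec({
    'ps': ['host0:2222'],
    'worker': ['host1:2222', 'host2:2222'],
})
# the custom ConfigProto goes to the Server instead of to the trainer
server = tf.train.Server(
    cluster, job_name='worker', task_index=0,
    config=tf.ConfigProto(allow_soft_placement=True))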