Commit 12cd6e6c authored by Yuxin Wu

[Trainerv2] fix interface for distributed trainer.

parent a673974c
@@ -24,7 +24,13 @@ class NewSessionCreator(tf.train.SessionCreator):
"""
self.target = target
if config is None:
# the distributed trainer doesn't support a user-provided config;
# we set this attribute so that they can check it
self.user_provided_config = False
config = get_default_sess_config()
else:
self.user_provided_config = True
self.config = config
self.graph = graph
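In effect, `NewSessionCreator` now records whether the caller supplied its own session config, so distributed trainers can refuse it later (see the `_set_session_creator` hunk below). A minimal sketch of the resulting behaviour, assuming the constructor keeps the `target`/`config`/`graph` arguments seen in the surrounding context:

import tensorflow as tf
from tensorpack.tfutils.sesscreate import NewSessionCreator   # module path assumed

creator = NewSessionCreator()                 # no config -> get_default_sess_config() is used
assert creator.user_provided_config is False  # distributed trainers accept this

custom = tf.ConfigProto(allow_soft_placement=True)
creator = NewSessionCreator(config=custom)
assert creator.user_provided_config is True   # distributed trainers will raise ValueError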
@@ -101,8 +101,6 @@ class TrainConfig(object):
else:
self.session_creator = session_creator
assert session_config is None, "Cannot set both session_creator and session_config!"
# only used by DistributedTrainer for assertion!
self.session_config = session_config
if steps_per_epoch is None:
try:
@@ -85,7 +85,7 @@ class DistributedTrainerReplicated(Trainer):
def _set_session_creator(self):
old_sess_creator = self._config.session_creator
if not isinstance(old_sess_creator, NewSessionCreator) \
or self._config.session_config is not None:
or old_sess_creator.user_provided_config:
raise ValueError(
"Cannot set session_creator or session_config for distributed training! "
"To use a custom session config, pass it to tf.train.Server.")
@@ -231,8 +231,7 @@ class SingleCostTrainer(Trainer):
These callbacks will be automatically added when you call `train()`.
So you can usually ignore the return value.
"""
assert not input.setup_done()
input_callbacks = input.setup(inputs_desc)
input_callbacks = self._setup_input(inputs_desc, input)
train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
self._internal_callbacks = input_callbacks + train_callbacks
return self._internal_callbacks
@@ -240,3 +239,7 @@ class SingleCostTrainer(Trainer):
@abstractmethod
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
pass
def _setup_input(self, inputs_desc, input):
assert not input.setup_done()
return input.setup(inputs_desc)
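With this split, input setup becomes an overridable hook: a subclass that must not build an input pipeline on some workers (as the distributed trainer does for the 'ps' job below) can override `_setup_input` without reimplementing the rest of the setup. A rough sketch of such an override, with `MyTrainer` purely illustrative:

class MyTrainer(SingleCostTrainer):
    def _setup_input(self, inputs_desc, input):
        # e.g. skip or customize input setup under some condition,
        # then fall back to the default behaviour
        assert not input.setup_done()
        return input.setup(inputs_desc)

    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
        # build self.train_op here and return any extra callbacks
        return []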
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: trainers.py
import os
from ..callbacks.graph import RunOp
@@ -17,7 +18,7 @@ from ..tfutils import get_global_step_var
from ..tfutils.distributed import get_distributed_session_creator
from ..input_source import QueueInput
from .base import Trainer, SingleCostTrainer
from .base import SingleCostTrainer
__all__ = ['SimpleTrainer',
'QueueInputTrainer',
@@ -32,8 +33,7 @@ class SimpleTrainer(SingleCostTrainer):
Single-GPU single-cost single-tower trainer.
"""
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
self.train_op = SimpleBuilder().build(
input, get_cost_fn, get_opt_fn)
self.train_op = SimpleBuilder().build(input, get_cost_fn, get_opt_fn)
return []
@@ -126,17 +126,13 @@ class DistributedTrainerReplicated(SingleCostTrainer):
self.is_chief = False
logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))
def train(self,
inputs_desc, input, get_cost_fn, get_opt_fn,
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch):
def _setup_input(self, inputs_desc, input):
if self.job_name == 'ps':
# ps shouldn't set up input either
logger.info("Running ps {}".format(self.server.server_def.task_index))
logger.info("Kill me with 'kill {}'".format(os.getpid()))
self.server.join() # this will never return tensorflow#4713
return
self.server.join() # this function will never return tensorflow#4713
raise RuntimeError("This is a bug in tensorpack. Server.join() for ps should never return!")
with override_to_local_variable():
get_global_step_var() # gs should be local
@@ -144,14 +140,8 @@ class DistributedTrainerReplicated(SingleCostTrainer):
# TODO This is not good because we don't know from here
# whether something should be global or local. We now assume
# they should be local.
input_callbacks = input.setup(inputs_desc)
train_callbacks = self.setup_graph(input, get_cost_fn, get_opt_fn)
Trainer.train(
self,
callbacks + input_callbacks + train_callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch)
assert not input.setup_done()
return input.setup(inputs_desc)
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
self.train_op, initial_sync_op, model_sync_op = self._builder.build(
@@ -174,9 +164,10 @@ class DistributedTrainerReplicated(SingleCostTrainer):
return callbacks
def initialize(self, session_creator, session_init):
if not isinstance(session_creator, NewSessionCreator):
if not isinstance(session_creator, NewSessionCreator) or \
session_creator.user_provided_config:
raise ValueError(
"Cannot set session_creator for distributed training! "
"Cannot set session_creator or session_config for distributed training! "
"To use a custom session config, pass it to tf.train.Server.")
super(DistributedTrainerReplicated, self).initialize(
get_distributed_session_creator(), session_init)
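As the error message says, a custom session config for distributed training now has to be given to the TensorFlow server rather than to a session creator. A hedged sketch (hosts, job name and task index are placeholders):

import tensorflow as tf

cluster = tf.train.ClusterSpec({
    'ps': ['host0:2222'],
    'worker': ['host1:2222', 'host2:2222']})
config = tf.ConfigProto(allow_soft_placement=True)   # the custom session config
server = tf.train.Server(cluster, job_name='worker', task_index=0, config=config)
# the trainer is then built around this server; no session_creator/session_config is passed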