Commit 4e644290 authored by Yuxin Wu

Add a function to apply general prefetch policies

parent fff3f2d3
@@ -239,8 +239,10 @@ class SingleCostTrainer(Trainer):
         Args:
             inputs_desc ([InputDesc]):
             input (InputSource):
-            get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tenosrs and return a cost tensor
-            get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
+            get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tensors and returns a cost tensor.
+                Might get called multiple times for data-parallel training or inference.
+            get_opt_fn (-> tf.train.Optimizer): callable which returns an
+                optimizer. Will only be called once.

         Returns:
             [Callback]: a (possibly empty) list of callbacks needed for training.
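A minimal sketch of the two callables documented above, assuming a toy classifier; the layer shapes and learning rate are hypothetical placeholders, and the commented-out call shows where they plug into setup_graph:

    import tensorflow as tf

    def get_cost_fn(image, label):
        # May be called once per tower under data-parallel training,
        # so it must be safe to call repeatedly; returns a scalar cost tensor.
        logits = tf.layers.dense(tf.reshape(image, [-1, 28 * 28]), 10)
        return tf.losses.sparse_softmax_cross_entropy(labels=label, logits=logits)

    def get_opt_fn():
        # Called exactly once; the returned optimizer is shared by all towers.
        return tf.train.GradientDescentOptimizer(1e-2)

    # callbacks = trainer.setup_graph(inputs_desc, input, get_cost_fn, get_opt_fn)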
@@ -5,18 +5,35 @@
 import tensorflow as tf

 from ..input_source import (
-    FeedInput, QueueInput, StagingInputWrapper, DummyConstantInput)
+    InputSource, FeedInput, QueueInput, StagingInputWrapper, DummyConstantInput)
 from ..train.config import TrainConfig
 from .base import SingleCostTrainer
 from .trainers import SimpleTrainer, DistributedTrainerReplicated

-__all__ = ['launch_train_with_config', 'TrainConfig']
+__all__ = ['launch_train_with_config', 'TrainConfig', 'apply_default_prefetch']


-def _maybe_gpu_prefetch(input, towers, gpu_prefetch):
-    # seem to only improve on >1 GPUs
-    if len(towers) > 1 and gpu_prefetch:
+def apply_default_prefetch(input_source_or_dataflow, trainer, towers):
+    """
+    Apply a set of default rules to make a fast :class:`InputSource`.
+
+    Args:
+        input_source_or_dataflow (InputSource | DataFlow):
+        trainer (Trainer):
+        towers ([int]): list of GPU ids.
+    """
+    if not isinstance(input_source_or_dataflow, InputSource):
+        # to mimic the same behavior of the old trainer interface
+        if type(trainer) == SimpleTrainer:
+            input = FeedInput(input_source_or_dataflow)
+        else:
+            input = QueueInput(input_source_or_dataflow)
+    else:
+        input = input_source_or_dataflow
+
+    if len(towers) > 1:
+        # seems to only improve on >1 GPUs
+        assert not isinstance(trainer, SimpleTrainer)
+        assert tf.test.is_gpu_available()
+
+        if not isinstance(input, (StagingInputWrapper, DummyConstantInput)):
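A quick usage sketch of the new helper (df and trainer stand in for a real DataFlow and Trainer instance; the tower ids are placeholders):

    # A DataFlow gets wrapped into FeedInput (SimpleTrainer) or QueueInput;
    # an existing InputSource passes through unchanged. With more than one
    # tower, the branch above additionally asserts GPUs are available and
    # applies staging-area prefetch unless the input is already wrapped.
    input = apply_default_prefetch(df, trainer, towers=[0, 1])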
@@ -26,7 +43,8 @@ def _maybe_gpu_prefetch(input, towers, gpu_prefetch):
 def launch_train_with_config(config, trainer):
     """
-    To mimic the old training interface, with a trainer and a config.
+    Train with a :class:`TrainConfig` and a new version of :class:`Trainer`, to
+    mimic the old training interface.

     Args:
         config (TrainConfig):
@@ -49,18 +67,8 @@ def launch_train_with_config(config, trainer):
     model = config.model
     inputs_desc = model.get_inputs_desc()
-    input = config.data
-
-    # some check & input wrappers to mimic same behavior of the old trainer interface
-    if input is None:
-        if type(trainer) == SimpleTrainer:
-            input = FeedInput(config.dataflow)
-        else:
-            input = QueueInput(config.dataflow)
-    if config.nr_tower > 1:
-        assert not isinstance(trainer, SimpleTrainer)
-    input = _maybe_gpu_prefetch(input, config.tower, True)
+    input = config.data or config.dataflow
+    input = apply_default_prefetch(input, trainer, config.tower)

     if isinstance(trainer, DistributedTrainerReplicated) and \
             config.session_config is not None:
@@ -72,10 +80,6 @@ def launch_train_with_config(config, trainer):
         inputs_desc, input,
         model.build_graph_get_cost, model.get_optimizer)
     trainer.train(
-        config.callbacks,
-        config.monitors,
-        config.session_creator,
-        config.session_init,
-        config.steps_per_epoch,
-        config.starting_epoch,
-        config.max_epoch)
+        config.callbacks, config.monitors,
+        config.session_creator, config.session_init,
+        config.steps_per_epoch, config.starting_epoch, config.max_epoch)
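End to end, the preserved old-style entry point then looks roughly like this, assuming the usual top-level exports; MyModel and my_dataflow are hypothetical stand-ins for a ModelDesc and a DataFlow:

    from tensorpack import TrainConfig, SimpleTrainer, launch_train_with_config

    config = TrainConfig(
        model=MyModel(),        # hypothetical ModelDesc subclass
        dataflow=my_dataflow,   # picked up since config.data is not given
        steps_per_epoch=100,
        max_epoch=10,
    )
    # SimpleTrainer -> FeedInput; other trainers -> QueueInput (+ GPU prefetch)
    launch_train_with_config(config, SimpleTrainer())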