Commit d2939cf8 authored by Yuxin Wu

Builder with get_grads instead of input+get_cost

parent 2c5b1bec
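The builders' `build()` no longer takes an `InputSource` plus a cost function; it now receives a callable that directly produces `[(grad, var)]`. A minimal sketch of the call-pattern change, assuming `builder`, `model` and `input` are an already-constructed `GraphBuilder`, `ModelDesc` and set-up `InputSource` (illustrative only, not part of the commit):

```python
# Illustrative sketch only; `builder`, `model` and `input` are assumed to be an
# already-constructed GraphBuilder, ModelDesc and set-up InputSource.

# Before this commit: the builder received the InputSource and a cost function
# and derived the gradients internally (via DataParallelBuilder._make_fn):
#   train_op = builder.build(input, model.build_graph_get_cost, model.get_optimizer)

# After this commit: the builder only consumes a gradient-producing callable;
# the trainer wires the InputSource into it, as the trainer diffs below do.
train_op = builder.build(
    lambda: model.build_graph_get_grads(*input.get_input_tensors()),
    model.get_optimizer)
```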
......@@ -186,12 +186,10 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
return tf.group(*queue_ops, name=name)
def build(self, input, get_cost_fn, get_opt_fn):
def build(self, get_grad_fn, get_opt_fn):
"""
Args:
input (InputSource): the input. Should have been setup.
get_cost_fn ([tf.Tensor] -> tf.Tensor): callable which takes a list of input tensor
and returns a cost tensor
get_grad_fn (-> [(grad, var)]): callable which returns a list of (gradient, variable) pairs
get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
Returns:
......@@ -211,9 +209,7 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
get_opt_fn = memoized(get_opt_fn)
# Build the optimizer first, before entering any tower.
# This makes sure that learning_rate is a global variable (what we expect)
get_opt_fn()
get_grad_fn, _ = DataParallelBuilder._make_fn(input, get_cost_fn, get_opt_fn)
get_opt_fn() # TODO get_opt_fn called before main graph was built
# Ngpu * Nvar * 2
grad_list = DataParallelBuilder.build_on_towers(
......
......@@ -9,6 +9,8 @@ import tensorflow as tf
import six
from ..utils.argtools import memoized
from ..tfutils.gradproc import FilterNoneGrad
from ..tfutils.tower import get_current_tower_context
from ..input_source import InputSource
from ..models.regularize import regularize_cost_from_collection
......@@ -149,8 +151,25 @@ class ModelDesc(ModelDescBase):
def build_graph_get_cost(self, *inputs):
"""
Build the graph from inputs and return the cost tensor.
This is useful for most :class:`GraphBuilder` classes, which expect such a function.
"""
self.build_graph(inputs)
return self.get_cost()
def build_graph_get_grads(self, *inputs):
"""
Build the graph from inputs and return the grads.
This is useful for most :class:`GraphBuilder` classes, which expect such a function.
Returns:
[(grad, var)]
"""
ctx = get_current_tower_context()
cost = self.build_graph_get_cost(*inputs)
varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
opt = self.get_optimizer()
grads = opt.compute_gradients(
cost, var_list=varlist,
gate_gradients=False, colocate_gradients_with_ops=True)
grads = FilterNoneGrad().process(grads)
return grads
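The new `build_graph_get_grads()` has to run under an active `TowerContext`, since it filters the trainable variables by the current tower's variable-scope name before computing gradients. A single-tower usage sketch, assumed from the `SimpleTrainer` change later in this commit (`model` and `input` are a `ModelDesc` and an already set-up feedfree `InputSource`):

```python
# Assumed single-tower sketch, mirroring the SimpleTrainer change further down;
# `model` is a ModelDesc and `input` an already set-up feedfree InputSource.
from tensorpack.tfutils.tower import TowerContext

with TowerContext('', is_training=True):
    # needs an active TowerContext so filter_vars_by_vs_name() sees the tower
    grads = model.build_graph_get_grads(*input.get_input_tensors())

train_op = model.get_optimizer().apply_gradients(grads, name='min_op')
```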
......@@ -8,19 +8,16 @@ import six
from six.moves import zip, range
from ..utils import logger
from ..utils.argtools import memoized
from ..tfutils.gradproc import FilterNoneGrad
from ..tfutils.tower import TowerContext, get_current_tower_context
from ..tfutils.tower import TowerContext
from ..tfutils.common import get_tf_version_number
from ..tfutils.collection import backup_collection, restore_collection
from ..tfutils.gradproc import ScaleGradient
from ..utils.naming import TOWER_FREEZE_KEYS
from ..input_source import FeedfreeInput
from .utils import LeastLoadedDeviceSetter, override_to_local_variable
__all__ = ['GraphBuilder', 'SimpleBuilder',
__all__ = ['GraphBuilder',
'SyncMultiGPUParameterServerBuilder', 'DataParallelBuilder',
'SyncMultiGPUReplicatedBuilder', 'AsyncMultiGPUBuilder']
......@@ -32,35 +29,6 @@ class GraphBuilder(object):
pass
class SimpleBuilder(GraphBuilder):
"""
Single-cost single-optimizer single-tower training.
"""
def build(self, input, get_cost_fn, get_opt_fn):
"""
Args:
input (InputSource): the input. Should have been setup.
get_cost_fn ([tf.Tensor] -> tf.Tensor): callable which takes a list of input tensor
and returns a cost tensor
get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
Returns:
tf.Operation: the training op
"""
assert input.setup_done()
with TowerContext('', is_training=True) as ctx:
cost = get_cost_fn(*input.get_input_tensors())
varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
opt = get_opt_fn()
grads = opt.compute_gradients(
cost, var_list=varlist,
gate_gradients=False, colocate_gradients_with_ops=True)
grads = FilterNoneGrad().process(grads)
train_op = opt.apply_gradients(grads, name='min_op')
return train_op
class DataParallelBuilder(GraphBuilder):
def __init__(self, towers):
"""
......@@ -132,27 +100,6 @@ class DataParallelBuilder(GraphBuilder):
restore_collection(backup)
return ret
@staticmethod
def _make_fn(input, get_cost_fn, get_opt_fn):
# internal use only
assert input.setup_done(), "InputSource must have been setup before calling GraphBuilder!"
assert isinstance(input, FeedfreeInput), input
get_opt_fn = memoized(get_opt_fn)
def get_grad_fn():
ctx = get_current_tower_context()
cost = get_cost_fn(*input.get_input_tensors())
varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
opt = get_opt_fn()
grads = opt.compute_gradients(
cost, var_list=varlist,
gate_gradients=False, colocate_gradients_with_ops=True)
grads = FilterNoneGrad().process(grads)
return grads
return get_grad_fn, get_opt_fn
class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
"""
......@@ -195,12 +142,10 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
new_tower_grads.append((grad, v))
return new_tower_grads
def build(self, input, get_cost_fn, get_opt_fn):
def build(self, get_grad_fn, get_opt_fn):
"""
Args:
input (InputSource):
get_cost_fn ([tf.Tensor] -> tf.Tensor): callable which takes a list of input tensor
and returns a cost tensor
get_grad_fn (-> [(grad, var)]): callable which returns a list of (gradient, variable) pairs
get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
Returns:
......@@ -213,8 +158,6 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
devices = [tf.train.replica_device_setter(
worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]
get_grad_fn, get_opt_fn = DataParallelBuilder._make_fn(input, get_cost_fn, get_opt_fn)
grad_list = DataParallelBuilder.build_on_towers(self.towers, get_grad_fn, devices)
DataParallelBuilder._check_grad_list(grad_list)
......@@ -266,12 +209,10 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
# NVar * NGPU * 2
return new_tower_grads
def build(self, input, get_cost_fn, get_opt_fn):
def build(self, get_grad_fn, get_opt_fn):
"""
Args:
input (InputSource): the input. Should have been setup.
get_cost_fn ([tf.Tensor] -> tf.Tensor): callable which takes a list of input tensor
and returns a cost tensor
get_grad_fn (-> [(grad, var)]): callable which returns a list of (gradient, variable) pairs
get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
Returns:
......@@ -285,8 +226,6 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
"""
raw_devices = ['/gpu:{}'.format(k) for k in self.towers]
get_grad_fn, get_opt_fn = DataParallelBuilder._make_fn(input, get_cost_fn, get_opt_fn)
grad_list = DataParallelBuilder.build_on_towers(
self.towers,
get_grad_fn,
......@@ -356,12 +295,10 @@ class AsyncMultiGPUBuilder(DataParallelBuilder):
super(AsyncMultiGPUBuilder, self).__init__(towers)
self._scale_gradient = scale_gradient
def build(self, input, get_cost_fn, get_opt_fn):
def build(self, get_grad_fn, get_opt_fn):
"""
Args:
input (InputSource): the input. Should have been setup.
get_cost_fn ([tf.Tensor] -> tf.Tensor): callable which takes a list of input tensor
and returns a cost tensor
get_grad_fn (-> [(grad, var)]): callable which returns a list of (gradient, variable) pairs
get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
Returns:
......@@ -376,8 +313,6 @@ class AsyncMultiGPUBuilder(DataParallelBuilder):
devices = [tf.train.replica_device_setter(
worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]
get_grad_fn, get_opt_fn = DataParallelBuilder._make_fn(input, get_cost_fn, get_opt_fn)
grad_list = DataParallelBuilder.build_on_towers(self.towers, get_grad_fn, devices)
DataParallelBuilder._check_grad_list(grad_list)
......
......@@ -64,7 +64,9 @@ class DistributedTrainerReplicated(Trainer):
self._config.callbacks.extend(cbs)
self.train_op, initial_sync_op, model_sync_op = self._builder.build(
self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)
lambda: self.model.build_graph_get_grads(
*self._input_source.get_input_tensors()),
self.model.get_optimizer)
# initial local_vars syncing
cb = RunOp(lambda: initial_sync_op,
......
......@@ -70,7 +70,9 @@ class SyncMultiGPUTrainerParameterServer(Trainer):
self.train_op = SyncMultiGPUParameterServerBuilder(
self._config.tower, self._ps_device).build(
self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)
lambda: self.model.build_graph_get_grads(
*self._input_source.get_input_tensors()),
self.model.get_optimizer)
self._config.callbacks.extend(callbacks)
......@@ -100,8 +102,11 @@ class SyncMultiGPUTrainerReplicated(Trainer):
def _setup(self):
callbacks = self._input_source.setup(self.model.get_inputs_desc())
self.train_op, post_init_op = SyncMultiGPUReplicatedBuilder(self._config.tower).build(
self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)
self.train_op, post_init_op = SyncMultiGPUReplicatedBuilder(
self._config.tower).build(
lambda: self.model.build_graph_get_grads(
*self._input_source.get_input_tensors()),
self.model.get_optimizer)
cb = RunOp(
lambda: post_init_op,
......@@ -129,6 +134,8 @@ class AsyncMultiGPUTrainer(Trainer):
self.train_op = AsyncMultiGPUBuilder(
self._config.tower, self._scale_gradient).build(
self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)
lambda: self.model.build_graph_get_grads(
*self._input_source.get_input_tensors()),
self.model.get_optimizer)
self._config.callbacks.extend(callbacks)
......@@ -10,13 +10,15 @@ import six
from abc import abstractmethod, ABCMeta
from ..utils import logger
from ..utils.argtools import call_only_once
from ..utils.argtools import call_only_once, memoized
from ..input_source import FeedfreeInput
from ..callbacks import Callback, Callbacks
from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sessinit import JustCurrentSession
from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.tower import TowerFuncWrapper
from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context
from ..tfutils.gradproc import FilterNoneGrad
from ..callbacks.steps import MaintainStepCounter
from ..train.base import StopTraining, TrainLoop
......@@ -232,9 +234,13 @@ class SingleCostTrainer(Trainer):
@call_only_once
def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
"""
Build the main training graph. Defaults to do nothing.
You can either override it in subclasses, or build the graph outside
the trainer.
Responsible for building the main training graph.
Args:
inputs_desc ([InputDesc]):
input (InputSource):
get_cost_fn ([tf.Tensor] -> tf.Tensor): callable which takes some input tensors and returns a cost tensor
get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
Returns:
[Callback]: a (possibly empty) list of callbacks needed for training.
......@@ -242,10 +248,11 @@ class SingleCostTrainer(Trainer):
So you can usually ignore the return value.
"""
get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
get_opt_fn = memoized(get_opt_fn)
input_callbacks = self._setup_input(inputs_desc, input)
train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
self._internal_callbacks = input_callbacks + train_callbacks
self.inputs_desc = inputs_desc
self.get_cost_fn = get_cost_fn
return self._internal_callbacks
......@@ -257,3 +264,26 @@ class SingleCostTrainer(Trainer):
def _setup_input(self, inputs_desc, input):
assert not input.setup_done()
return input.setup(inputs_desc)
def _make_get_grad_fn(self, input, get_cost_fn, get_opt_fn):
"""
Returns:
a get_grad_fn for GraphBuilder to use.
"""
# internal use only
assert input.setup_done()
assert isinstance(input, FeedfreeInput), input
def get_grad_fn():
ctx = get_current_tower_context()
cost = get_cost_fn(*input.get_input_tensors())
varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
opt = get_opt_fn()
grads = opt.compute_gradients(
cost, var_list=varlist,
gate_gradients=False, colocate_gradients_with_ops=True)
grads = FilterNoneGrad().process(grads)
return grads
return get_grad_fn
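`_make_get_grad_fn` replaces the removed `DataParallelBuilder._make_fn`; the closure it returns is what the data-parallel builders now consume. A sketch of the contract, assumed from the `build_on_towers` call sites elsewhere in this diff (`towers` and `devices` are the usual per-GPU lists, and `get_grad_fn` is what a trainer built via `_make_get_grad_fn(...)`):

```python
# Assumed sketch, based on the build_on_towers call sites in this diff.
# `towers` is the list of GPU indices, `devices` the matching device list,
# and get_grad_fn the closure returned by _make_get_grad_fn(...).
grad_list = DataParallelBuilder.build_on_towers(towers, get_grad_fn, devices)

# build_on_towers calls get_grad_fn once per tower, inside that tower's
# TowerContext, so get_current_tower_context() in the closure resolves to the
# right variable scope.  grad_list is therefore Ngpu lists of (grad, var)
# pairs -- the "Ngpu * Nvar * 2" structure the builders then aggregate.
```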
......@@ -10,10 +10,10 @@ from ..tfutils.sesscreate import NewSessionCreator
from ..utils import logger
from ..tfutils import get_global_step_var
from ..tfutils.distributed import get_distributed_session_creator
from ..tfutils.tower import TowerContext
from ..input_source import QueueInput
from ..graph_builder.training import (
SimpleBuilder,
SyncMultiGPUParameterServerBuilder,
SyncMultiGPUReplicatedBuilder,
AsyncMultiGPUBuilder)
......@@ -35,7 +35,10 @@ class SimpleTrainer(SingleCostTrainer):
Single-GPU single-cost single-tower trainer.
"""
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
self.train_op = SimpleBuilder().build(input, get_cost_fn, get_opt_fn)
with TowerContext('', is_training=True):
grads = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)()
opt = get_opt_fn()
self.train_op = opt.apply_gradients(grads, name='min_op')
return []
......@@ -60,7 +63,8 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
super(SyncMultiGPUTrainerParameterServer, self).__init__()
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
self.train_op = self._builder.build(input, get_cost_fn, get_opt_fn)
self.train_op = self._builder.build(
self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
return []
......@@ -78,7 +82,8 @@ class AsyncMultiGPUTrainer(SingleCostTrainer):
super(AsyncMultiGPUTrainer, self).__init__()
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
self.train_op = self._builder.build(input, get_cost_fn, get_opt_fn)
self.train_op = self._builder.build(
self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
return []
......@@ -96,7 +101,7 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
self.train_op, post_init_op = self._builder.build(
input, get_cost_fn, get_opt_fn)
self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
cb = RunOp(
post_init_op,
......@@ -147,7 +152,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
self.train_op, initial_sync_op, model_sync_op = self._builder.build(
input, get_cost_fn, get_opt_fn)
self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
callbacks = []
# initial local_vars syncing
......