Commit d2939cf8 authored by Yuxin Wu

Builder with get_grads instead of input+get_cost

parent 2c5b1bec
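
In short: GraphBuilder.build() no longer receives an InputSource plus a cost function; it now receives a single zero-argument callable that returns the gradients. A minimal before/after sketch of the call site (illustrative only; `builder`, `model` and `input` are placeholder names, not part of the commit):

    # before this commit: the builder pulled tensors from the InputSource and built the cost itself
    train_op = builder.build(input, model.build_graph_get_cost, model.get_optimizer)

    # after this commit: the caller hands over a get_grad_fn that already returns [(grad, var)]
    train_op = builder.build(
        lambda: model.build_graph_get_grads(*input.get_input_tensors()),
        model.get_optimizer)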
@@ -186,12 +186,10 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
         return tf.group(*queue_ops, name=name)
 
-    def build(self, input, get_cost_fn, get_opt_fn):
+    def build(self, get_grad_fn, get_opt_fn):
         """
         Args:
-            input (InputSource): the input. Should have been setup.
-            get_cost_fn ([tf.Tensor] -> tf.Tensor): callable which takes a list of input tensor
-                and returns a cost tensor
+            get_grad_fn (-> [(grad, var)]):
             get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
 
         Returns:
@@ -211,9 +209,7 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
         get_opt_fn = memoized(get_opt_fn)
         # Build the optimizer first, before entering any tower.
         # This makes sure that learning_rate is a global variable (what we expect)
-        get_opt_fn()
-
-        get_grad_fn, _ = DataParallelBuilder._make_fn(input, get_cost_fn, get_opt_fn)
+        get_opt_fn()  # TODO get_opt_fn called before main graph was built
 
         # Ngpu * Nvar * 2
         grad_list = DataParallelBuilder.build_on_towers(
...
@@ -9,6 +9,8 @@ import tensorflow as tf
 import six
 from ..utils.argtools import memoized
+from ..tfutils.gradproc import FilterNoneGrad
+from ..tfutils.tower import get_current_tower_context
 from ..input_source import InputSource
 from ..models.regularize import regularize_cost_from_collection
@@ -149,8 +151,25 @@ class ModelDesc(ModelDescBase):
     def build_graph_get_cost(self, *inputs):
         """
         Build the graph from inputs and return the cost tensor.
+        This is useful for most of the :class:`GraphBuilder` which expects
+        such a function.
         """
         self.build_graph(inputs)
         return self.get_cost()
+
+    def build_graph_get_grads(self, *inputs):
+        """
+        Build the graph from inputs and return the grads.
+        This is useful for most of the :class:`GraphBuilder` which expects such a function.
+
+        Returns:
+            [(grad, var)]
+        """
+        ctx = get_current_tower_context()
+        cost = self.build_graph_get_cost(*inputs)
+        varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
+        opt = self.get_optimizer()
+        grads = opt.compute_gradients(
+            cost, var_list=varlist,
+            gate_gradients=False, colocate_gradients_with_ops=True)
+        grads = FilterNoneGrad().process(grads)
+        return grads
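
The list returned above has one (grad, var) pair per trainable variable of the current tower; tf.train.Optimizer.compute_gradients emits None for variables the cost does not depend on, which is why the result is passed through FilterNoneGrad. A self-contained TF 1.x illustration of that behaviour (plain TensorFlow, not tensorpack code; the variable names are made up):

    import tensorflow as tf

    x = tf.constant([[1.0]])
    w = tf.get_variable('w', [1, 1])
    unused = tf.get_variable('unused', [1])   # not on the path to the cost
    cost = tf.reduce_sum(tf.matmul(x, w))

    opt = tf.train.GradientDescentOptimizer(0.1)
    pairs = opt.compute_gradients(cost, var_list=[w, unused])
    # pairs == [(grad_of_w, w), (None, unused)]
    pairs = [(g, v) for g, v in pairs if g is not None]   # the kind of entry FilterNoneGrad().process() removes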
@@ -8,19 +8,16 @@ import six
 from six.moves import zip, range
 from ..utils import logger
-from ..utils.argtools import memoized
-from ..tfutils.gradproc import FilterNoneGrad
-from ..tfutils.tower import TowerContext, get_current_tower_context
+from ..tfutils.tower import TowerContext
 from ..tfutils.common import get_tf_version_number
 from ..tfutils.collection import backup_collection, restore_collection
 from ..tfutils.gradproc import ScaleGradient
 from ..utils.naming import TOWER_FREEZE_KEYS
-from ..input_source import FeedfreeInput
 from .utils import LeastLoadedDeviceSetter, override_to_local_variable
 
-__all__ = ['GraphBuilder', 'SimpleBuilder',
+__all__ = ['GraphBuilder',
            'SyncMultiGPUParameterServerBuilder', 'DataParallelBuilder',
            'SyncMultiGPUReplicatedBuilder', 'AsyncMultiGPUBuilder']
@@ -32,35 +29,6 @@ class GraphBuilder(object):
     pass
 
 
-class SimpleBuilder(GraphBuilder):
-    """
-    Single-cost single-optimizer single-tower training.
-    """
-    def build(self, input, get_cost_fn, get_opt_fn):
-        """
-        Args:
-            input (InputSource): the input. Should have been setup.
-            get_cost_fn ([tf.Tensor] -> tf.Tensor): callable which takes a list of input tensor
-                and returns a cost tensor
-            get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
-
-        Returns:
-            tf.Operation: the training op
-        """
-        assert input.setup_done()
-        with TowerContext('', is_training=True) as ctx:
-            cost = get_cost_fn(*input.get_input_tensors())
-            varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
-            opt = get_opt_fn()
-            grads = opt.compute_gradients(
-                cost, var_list=varlist,
-                gate_gradients=False, colocate_gradients_with_ops=True)
-            grads = FilterNoneGrad().process(grads)
-            train_op = opt.apply_gradients(grads, name='min_op')
-            return train_op
 class DataParallelBuilder(GraphBuilder):
     def __init__(self, towers):
         """
@@ -132,27 +100,6 @@ class DataParallelBuilder(GraphBuilder):
         restore_collection(backup)
         return ret
 
-    @staticmethod
-    def _make_fn(input, get_cost_fn, get_opt_fn):
-        # internal use only
-        assert input.setup_done(), "InputSource must have been setup before calling GraphBuilder!"
-        assert isinstance(input, FeedfreeInput), input
-        get_opt_fn = memoized(get_opt_fn)
-
-        def get_grad_fn():
-            ctx = get_current_tower_context()
-            cost = get_cost_fn(*input.get_input_tensors())
-            varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
-            opt = get_opt_fn()
-            grads = opt.compute_gradients(
-                cost, var_list=varlist,
-                gate_gradients=False, colocate_gradients_with_ops=True)
-            grads = FilterNoneGrad().process(grads)
-            return grads
-
-        return get_grad_fn, get_opt_fn
 class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
     """
@@ -195,12 +142,10 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
             new_tower_grads.append((grad, v))
         return new_tower_grads
 
-    def build(self, input, get_cost_fn, get_opt_fn):
+    def build(self, get_grad_fn, get_opt_fn):
         """
         Args:
-            input (InputSource):
-            get_cost_fn ([tf.Tensor] -> tf.Tensor): callable which takes a list of input tensor
-                and returns a cost tensor
+            get_grad_fn (-> [(grad, var)]):
             get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
 
         Returns:
@@ -213,8 +158,6 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
         devices = [tf.train.replica_device_setter(
             worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]
 
-        get_grad_fn, get_opt_fn = DataParallelBuilder._make_fn(input, get_cost_fn, get_opt_fn)
-
         grad_list = DataParallelBuilder.build_on_towers(self.towers, get_grad_fn, devices)
         DataParallelBuilder._check_grad_list(grad_list)
@@ -266,12 +209,10 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
         # NVar * NGPU * 2
         return new_tower_grads
 
-    def build(self, input, get_cost_fn, get_opt_fn):
+    def build(self, get_grad_fn, get_opt_fn):
         """
         Args:
-            input (InputSource): the input. Should have been setup.
-            get_cost_fn ([tf.Tensor] -> tf.Tensor): callable which takes a list of input tensor
-                and returns a cost tensor
+            get_grad_fn (-> [(grad, var)]):
             get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
 
         Returns:
@@ -285,8 +226,6 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
         """
         raw_devices = ['/gpu:{}'.format(k) for k in self.towers]
 
-        get_grad_fn, get_opt_fn = DataParallelBuilder._make_fn(input, get_cost_fn, get_opt_fn)
-
         grad_list = DataParallelBuilder.build_on_towers(
             self.towers,
             get_grad_fn,
@@ -356,12 +295,10 @@ class AsyncMultiGPUBuilder(DataParallelBuilder):
         super(AsyncMultiGPUBuilder, self).__init__(towers)
         self._scale_gradient = scale_gradient
 
-    def build(self, input, get_cost_fn, get_opt_fn):
+    def build(self, get_grad_fn, get_opt_fn):
         """
         Args:
-            input (InputSource): the input. Should have been setup.
-            get_cost_fn ([tf.Tensor] -> tf.Tensor): callable which takes a list of input tensor
-                and returns a cost tensor
+            get_grad_fn (-> [(grad, var)]):
             get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
 
         Returns:
@@ -376,8 +313,6 @@ class AsyncMultiGPUBuilder(DataParallelBuilder):
         devices = [tf.train.replica_device_setter(
             worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]
 
-        get_grad_fn, get_opt_fn = DataParallelBuilder._make_fn(input, get_cost_fn, get_opt_fn)
-
         grad_list = DataParallelBuilder.build_on_towers(self.towers, get_grad_fn, devices)
         DataParallelBuilder._check_grad_list(grad_list)
...
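
With the _make_fn helper gone, these builders never see the InputSource anymore; whoever calls build() is responsible for constructing the gradient closure. A hedged sketch of such a closure, assuming an already-set-up FeedfreeInput `input` and a cost-returning `get_cost_fn` (it mirrors the _make_get_grad_fn helper added to SingleCostTrainer later in this commit; `make_get_grad_fn` itself is just an illustrative name):

    def make_get_grad_fn(input, get_cost_fn, get_opt_fn):
        # illustrative stand-in for the removed DataParallelBuilder._make_fn
        def get_grad_fn():
            ctx = get_current_tower_context()        # set by build_on_towers for each GPU
            cost = get_cost_fn(*input.get_input_tensors())
            varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
            grads = get_opt_fn().compute_gradients(
                cost, var_list=varlist,
                gate_gradients=False, colocate_gradients_with_ops=True)
            return FilterNoneGrad().process(grads)   # drop variables with no gradient
        return get_grad_fn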
@@ -64,7 +64,9 @@ class DistributedTrainerReplicated(Trainer):
         self._config.callbacks.extend(cbs)
 
         self.train_op, initial_sync_op, model_sync_op = self._builder.build(
-            self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)
+            lambda: self.model.build_graph_get_grads(
+                *self._input_source.get_input_tensors()),
+            self.model.get_optimizer)
 
         # initial local_vars syncing
         cb = RunOp(lambda: initial_sync_op,
...
@@ -70,7 +70,9 @@ class SyncMultiGPUTrainerParameterServer(Trainer):
         self.train_op = SyncMultiGPUParameterServerBuilder(
             self._config.tower, self._ps_device).build(
-                self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)
+                lambda: self.model.build_graph_get_grads(
+                    *self._input_source.get_input_tensors()),
+                self.model.get_optimizer)
 
         self._config.callbacks.extend(callbacks)
@@ -100,8 +102,11 @@ class SyncMultiGPUTrainerReplicated(Trainer):
     def _setup(self):
         callbacks = self._input_source.setup(self.model.get_inputs_desc())
 
-        self.train_op, post_init_op = SyncMultiGPUReplicatedBuilder(self._config.tower).build(
-            self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)
+        self.train_op, post_init_op = SyncMultiGPUReplicatedBuilder(
+            self._config.tower).build(
+                lambda: self.model.build_graph_get_grads(
+                    *self._input_source.get_input_tensors()),
+                self.model.get_optimizer)
 
         cb = RunOp(
             lambda: post_init_op,
@@ -129,6 +134,8 @@ class AsyncMultiGPUTrainer(Trainer):
         self.train_op = AsyncMultiGPUBuilder(
             self._config.tower, self._scale_gradient).build(
-                self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)
+                lambda: self.model.build_graph_get_grads(
+                    *self._input_source.get_input_tensors()),
+                self.model.get_optimizer)
 
         self._config.callbacks.extend(callbacks)
...
@@ -10,13 +10,15 @@ import six
 from abc import abstractmethod, ABCMeta
 from ..utils import logger
-from ..utils.argtools import call_only_once
+from ..utils.argtools import call_only_once, memoized
+from ..input_source import FeedfreeInput
 from ..callbacks import Callback, Callbacks
 from ..callbacks.monitor import Monitors, TrainingMonitor
 from ..tfutils.model_utils import describe_trainable_vars
 from ..tfutils.sessinit import JustCurrentSession
 from ..tfutils.sesscreate import ReuseSessionCreator
-from ..tfutils.tower import TowerFuncWrapper
+from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context
+from ..tfutils.gradproc import FilterNoneGrad
 from ..callbacks.steps import MaintainStepCounter
 from ..train.base import StopTraining, TrainLoop
@@ -232,9 +234,13 @@ class SingleCostTrainer(Trainer):
     @call_only_once
     def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
         """
-        Build the main training graph. Defaults to do nothing.
-        You can either override it in subclasses, or build the graph outside
-        the trainer.
+        Responsible for building the main training graph.
+
+        Args:
+            inputs_desc ([InputDesc]):
+            input (InputSource):
+            get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tensors and returns a cost tensor
+            get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
 
         Returns:
             [Callback]: a (possibly empty) list of callbacks needed for training.
@@ -242,10 +248,11 @@ class SingleCostTrainer(Trainer):
             So you can usually ignore the return value.
         """
         get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
+        get_opt_fn = memoized(get_opt_fn)
 
         input_callbacks = self._setup_input(inputs_desc, input)
         train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
-        self._internal_callbacks = input_callbacks + train_callbacks
 
+        self._internal_callbacks = input_callbacks + train_callbacks
         self.inputs_desc = inputs_desc
         self.get_cost_fn = get_cost_fn
         return self._internal_callbacks
@@ -257,3 +264,26 @@ class SingleCostTrainer(Trainer):
     def _setup_input(self, inputs_desc, input):
         assert not input.setup_done()
         return input.setup(inputs_desc)
+
+    def _make_get_grad_fn(self, input, get_cost_fn, get_opt_fn):
+        """
+        Returns:
+            a get_grad_fn for GraphBuilder to use.
+        """
+        # internal use only
+        assert input.setup_done()
+        assert isinstance(input, FeedfreeInput), input
+
+        def get_grad_fn():
+            ctx = get_current_tower_context()
+            cost = get_cost_fn(*input.get_input_tensors())
+            varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
+            opt = get_opt_fn()
+            grads = opt.compute_gradients(
+                cost, var_list=varlist,
+                gate_gradients=False, colocate_gradients_with_ops=True)
+            grads = FilterNoneGrad().process(grads)
+            return grads
+
+        return get_grad_fn
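
Note also the new get_opt_fn = memoized(get_opt_fn) in setup_graph: the user-supplied callable is cached so that the optimizer (and anything it creates, such as a learning-rate variable) is built exactly once, even though get_grad_fn calls it in every tower and the builder calls it again for apply_gradients. Roughly the following behaviour, sketched here with functools.lru_cache as a stand-in (not tensorpack's actual memoized implementation; the optimizer below is a made-up example):

    import functools
    import tensorflow as tf

    def get_opt_fn():                       # stand-in for the user-supplied callable
        return tf.train.GradientDescentOptimizer(0.1)

    get_opt_fn = functools.lru_cache(maxsize=None)(get_opt_fn)

    assert get_opt_fn() is get_opt_fn()     # every tower shares the same optimizer instance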
@@ -10,10 +10,10 @@ from ..tfutils.sesscreate import NewSessionCreator
 from ..utils import logger
 from ..tfutils import get_global_step_var
 from ..tfutils.distributed import get_distributed_session_creator
+from ..tfutils.tower import TowerContext
 from ..input_source import QueueInput
 from ..graph_builder.training import (
-    SimpleBuilder,
     SyncMultiGPUParameterServerBuilder,
     SyncMultiGPUReplicatedBuilder,
     AsyncMultiGPUBuilder)
@@ -35,7 +35,10 @@ class SimpleTrainer(SingleCostTrainer):
     Single-GPU single-cost single-tower trainer.
     """
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
-        self.train_op = SimpleBuilder().build(input, get_cost_fn, get_opt_fn)
+        with TowerContext('', is_training=True):
+            grads = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)()
+            opt = get_opt_fn()
+            self.train_op = opt.apply_gradients(grads, name='min_op')
         return []
@@ -60,7 +63,8 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
         super(SyncMultiGPUTrainerParameterServer, self).__init__()
 
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
-        self.train_op = self._builder.build(input, get_cost_fn, get_opt_fn)
+        self.train_op = self._builder.build(
+            self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
         return []
@@ -78,7 +82,8 @@ class AsyncMultiGPUTrainer(SingleCostTrainer):
         super(AsyncMultiGPUTrainer, self).__init__()
 
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
-        self.train_op = self._builder.build(input, get_cost_fn, get_opt_fn)
+        self.train_op = self._builder.build(
+            self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
         return []
@@ -96,7 +101,7 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
         self.train_op, post_init_op = self._builder.build(
-            input, get_cost_fn, get_opt_fn)
+            self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
 
         cb = RunOp(
             post_init_op,
@@ -147,7 +152,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
         self.train_op, initial_sync_op, model_sync_op = self._builder.build(
-            input, get_cost_fn, get_opt_fn)
+            self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
 
         callbacks = []
         # initial local_vars syncing
...
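
Taken together, the new contract can be exercised without any of the multi-GPU machinery. A self-contained TF 1.x sketch of the same get_grad_fn/get_opt_fn split that SimpleTrainer now uses (plain TensorFlow, not tensorpack; the toy model is made up for illustration):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 1])
    w = tf.get_variable('w', [1, 1])
    cost = tf.reduce_mean(tf.square(tf.matmul(x, w)))

    def get_opt_fn():
        return tf.train.GradientDescentOptimizer(0.1)

    def get_grad_fn():
        grads = get_opt_fn().compute_gradients(
            cost, var_list=tf.trainable_variables(),
            gate_gradients=False, colocate_gradients_with_ops=True)
        return [(g, v) for g, v in grads if g is not None]   # same filtering as FilterNoneGrad

    # what a builder (or the new SimpleTrainer) does with the two callables:
    train_op = get_opt_fn().apply_gradients(get_grad_fn(), name='min_op')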