Commit 82bf74c9 authored by Yuxin Wu

graph builder for simple trainer

parent 94eace54
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: training.py
from abc import ABCMeta, abstractmethod
import tensorflow as tf
import six
from ..tfutils.gradproc import FilterNoneGrad
from ..tfutils.tower import TowerContext
@six.add_metaclass(ABCMeta)
class GraphBuilder(object):
    """
    Abstract base class for graph builders.

    Concrete subclasses must override :meth:`build`; because the class uses
    ``ABCMeta``, a subclass which does not override it cannot be instantiated.
    """

    @abstractmethod
    def build(self, *args, **kwargs):
        """
        Build the graph. The accepted arguments and the return value are
        defined by each subclass.
        """
        pass
class SimpleGraphBuilder(GraphBuilder):
    """
    Graph builder for the simplest training setup:
    a single cost, a single optimizer and a single tower.
    """

    def build(self, input, get_cost_fn, get_opt_fn):
        """
        Create the training op inside an unnamed training tower.

        Args:
            input (InputSource): an input source which has already been set up.
            get_cost_fn ([tf.Tensor] -> tf.Tensor): a callable which receives
                the input tensors as positional arguments and returns the
                cost tensor.
            get_opt_fn (None -> tf.train.Optimizer): a callable taking no
                arguments which returns the optimizer.

        Returns:
            tf.Operation: the training op (created with name 'min_op').
        """
        with TowerContext('', is_training=True) as tower:
            # The cost must be built first, so that the variables it creates
            # exist by the time the trainable variables are collected.
            input_tensors = input.get_input_tensors()
            loss = get_cost_fn(*input_tensors)

            # Let the tower context decide which trainable variables to keep
            # (filtering by variable-scope name, going by the method name).
            trainable = tower.filter_vars_by_vs_name(tf.trainable_variables())

            optimizer = get_opt_fn()
            grads_and_vars = optimizer.compute_gradients(
                loss,
                var_list=trainable,
                gate_gradients=False,
                colocate_gradients_with_ops=True)
            # NOTE(review): FilterNoneGrad presumably drops the pairs whose
            # gradient is None -- confirm against tfutils.gradproc.
            grads_and_vars = FilterNoneGrad().process(grads_and_vars)
            return optimizer.apply_gradients(grads_and_vars, name='min_op')
...@@ -118,6 +118,7 @@ class TowerContext(object): ...@@ -118,6 +118,7 @@ class TowerContext(object):
assert ns == self._name, \ assert ns == self._name, \
"Name conflict: name_scope inside tower '{}' becomes '{}'!".format(self._name, ns) \ "Name conflict: name_scope inside tower '{}' becomes '{}'!".format(self._name, ns) \
+ " You may need a different name for the tower!" + " You may need a different name for the tower!"
return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
global _CurrentTowerContext global _CurrentTowerContext
......
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
from .base import Trainer from .base import Trainer
from ..utils import logger from ..utils import logger
from ..tfutils import TowerContext
from ..graph_builder.input_source import FeedInput from ..graph_builder.input_source import FeedInput
from ..graph_builder.training import SimpleGraphBuilder
__all__ = ['SimpleTrainer'] __all__ = ['SimpleTrainer']
...@@ -54,11 +54,12 @@ class SimpleTrainer(Trainer): ...@@ -54,11 +54,12 @@ class SimpleTrainer(Trainer):
[Callback]: the callbacks to be added [Callback]: the callbacks to be added
""" """
cbs = input.setup(model.get_inputs_desc()) cbs = input.setup(model.get_inputs_desc())
with TowerContext('', is_training=True):
model.build_graph(input) def get_cost(*inputs):
_, grads = model.get_cost_and_grad() model.build_graph(inputs)
opt = model.get_optimizer() return model.get_cost()
train_op = opt.apply_gradients(grads, name='min_op')
train_op = SimpleGraphBuilder().build(input, get_cost, model.get_optimizer)
return train_op, cbs return train_op, cbs
def _setup(self): def _setup(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment