Commit 728eb800 authored by Yuxin Wu

Allow some gradient options to be settable (#568, #427)

parent a3cc3a18
@@ -122,6 +122,18 @@ class SingleCostTrainer(TowerTrainer):
 
     To use a :class:`SingleCostTrainer` object, call `trainer.setup_graph(...); trainer.train(...)`.
     """
+    COLOCATE_GRADIENTS_WITH_OPS = True
+    """
+    See `tf.gradients`. This may affect performance when the backward op
+    does not support the device of the forward op.
+    """
+
+    GATE_GRADIENTS = False
+    """See `tf.gradients`."""
+
+    AGGREGATION_METHOD = tf.AggregationMethod.DEFAULT
+    """See `tf.gradients`."""
+
     @call_only_once
     def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
         """
@@ -182,7 +194,9 @@ class SingleCostTrainer(TowerTrainer):
             opt = get_opt_fn()
             grads = opt.compute_gradients(
                 cost, var_list=varlist,
-                gate_gradients=False, colocate_gradients_with_ops=True)
+                gate_gradients=self.GATE_GRADIENTS,
+                colocate_gradients_with_ops=self.COLOCATE_GRADIENTS_WITH_OPS,
+                aggregation_method=self.AGGREGATION_METHOD)
             grads = FilterNoneGrad().process(grads)
             return grads
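
With this change, the gradient options become class attributes that a user can override before building the graph. A minimal sketch of how they might be set (the `SimpleTrainer` subclass and the chosen values are illustrative, not part of this commit):

    import tensorflow as tf
    from tensorpack.train import SimpleTrainer  # a SingleCostTrainer subclass

    # Override the class attributes before setup_graph(); they are passed
    # through to opt.compute_gradients() and thus to tf.gradients.
    trainer = SimpleTrainer()
    trainer.COLOCATE_GRADIENTS_WITH_OPS = False  # let TF place backward ops freely
    trainer.GATE_GRADIENTS = True                # gate each op's gradients before use
    trainer.AGGREGATION_METHOD = tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N

    # trainer.setup_graph(inputs_desc, input, get_cost_fn, get_opt_fn)
    # trainer.train(...)

Because these are plain class attributes, they can also be overridden on the class itself (e.g. `SingleCostTrainer.GATE_GRADIENTS = True`) to affect every trainer built afterwards.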