Commit 4ebdd71d authored by Yuxin Wu

Filter none gradients in FeedfreeTrainer (fix #285)

parent 65317b6b
@@ -54,11 +54,20 @@ class FilterNoneGrad(GradientProcessor):
     Skip the update and print a warning (instead of crashing),
     when the gradient of certain variable is None.
     """
+    def __init__(self, verbose=True):
+        """
+        Args:
+            verbose (bool): whether to print warning about None gradients.
+        """
+        super(FilterNoneGrad, self).__init__()
+        self._verbose = verbose
+
     def _process(self, grads):
         g = []
         for grad, var in grads:
             if grad is None:
-                logger.warn("No Gradient w.r.t {}".format(var.op.name))
+                if self._verbose:
+                    logger.warn("No Gradient w.r.t {}".format(var.op.name))
             else:
                 g.append((grad, var))
         return g
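Note: with the new `verbose` flag, the warning about dropped gradients can be silenced when filtering is expected. A minimal sketch of standalone use, assuming a TF1 graph and the public `process()` wrapper around `_process()` (variable names are illustrative):

```python
import tensorflow as tf
from tensorpack.tfutils.gradproc import FilterNoneGrad

x = tf.get_variable('x', shape=[3])
y = tf.get_variable('y', shape=[3])
cost = tf.reduce_sum(x * x)          # y does not contribute to the cost

grads = tf.gradients(cost, [x, y])   # -> [<Tensor>, None]
grads_and_vars = list(zip(grads, [x, y]))

# Drops the (None, y) pair; verbose=False skips the warning.
grads_and_vars = FilterNoneGrad(verbose=False).process(grads_and_vars)
assert len(grads_and_vars) == 1
```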
@@ -50,7 +50,7 @@ def apply_grad_processors(opt, gradprocs):
 class _ApplyGradientProcessor(ProxyOptimizer):
     def __init__(self, opt, gradprocs):
-        self._gradprocs = [FilterNoneGrad()] + gradprocs
+        self._gradprocs = gradprocs[:]
         super(_ApplyGradientProcessor, self).__init__(opt)

     def apply_gradients(self, grads_and_vars,
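Note: the optimizer wrapper no longer prepends a `FilterNoneGrad` implicitly. If a processor list is applied to gradients that may still contain None entries, the filter now has to be listed explicitly. A hedged sketch using the existing `apply_grad_processors` helper (the regex and multiplier below are only for illustration):

```python
import tensorflow as tf
from tensorpack.tfutils.optimizer import apply_grad_processors
from tensorpack.tfutils.gradproc import FilterNoneGrad, ScaleGradient

base_opt = tf.train.MomentumOptimizer(0.01, 0.9)
# Filtering is now opt-in: include FilterNoneGrad yourself if needed.
opt = apply_grad_processors(
    base_opt,
    [FilterNoneGrad(), ScaleGradient(('fc.*/W', 0.1))])
```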
@@ -102,8 +102,6 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
             v = grad_and_vars[0][1]
             # average gradient
             all_grads = [g for (g, _) in grad_and_vars]
-            if not MultiGPUTrainerBase.check_none_grads(v.op.name, all_grads):
-                continue
             grad = tf.multiply(
                 tf.add_n(all_grads), 1.0 / nr_device)
             new_tower_grads.append((grad, v))
@@ -197,6 +195,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
             devices=self.raw_devices,
             var_strategy='replicated',
             vs_names=None)  # use the default vs names
+        MultiGPUTrainerBase._check_grad_list(grad_list)
         avg_grads = DistributedReplicatedTrainer._average_grads(grad_list, self.raw_devices)
         with tf.device(self.param_server_device):
@@ -6,6 +6,7 @@
 import tensorflow as tf
 from six.moves import zip

+from ..tfutils.gradproc import FilterNoneGrad
 from ..tfutils.tower import TowerContext, get_current_tower_context
 from .input_source import QueueInput, FeedfreeInput
@@ -41,22 +42,6 @@ class FeedfreeTrainerBase(Trainer):
     def run_step(self):
         """ Simply run ``self.train_op``."""
         self.hooked_sess.run(self.train_op)
-        # if not hasattr(self, 'cnt'):
-        #     self.cnt = 0
-        # else:
-        #     self.cnt += 1
-        # if self.cnt % 10 == 0:
-        #     # debug-benchmark code:
-        #     run_metadata = tf.RunMetadata()
-        #     self.sess.run([self.train_op],
-        #                   options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
-        #                   run_metadata=run_metadata
-        #                   )
-        #     from tensorflow.python.client import timeline
-        #     trace = timeline.Timeline(step_stats=run_metadata.step_stats)
-        #     trace_file = open('timeline.ctf.json', 'w')
-        #     trace_file.write(trace.generate_chrome_trace_format())
-        #     import sys; sys.exit()


 class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
@@ -77,6 +62,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
             gate_gradients=False,
             colocate_gradients_with_ops=True)
         grads = list(zip(grads, varlist))
+        grads = FilterNoneGrad().process(grads)
         return cost, grads
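Note: the filter now runs inside each tower, right after the gradients are computed, so variables without a gradient never reach the cross-tower code. A minimal sketch of why None entries cannot survive until averaging (names are illustrative):

```python
import tensorflow as tf

g = tf.constant([1.0, 2.0])
# Cross-tower averaging (as in _average_grads above) relies on tf.add_n:
avg = tf.multiply(tf.add_n([g, g]), 0.5)   # fine with real tensors
# tf.add_n([g, None]) would raise, so None gradients must be removed
# inside each tower before the per-variable lists are combined.
```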
@@ -12,7 +12,7 @@ from ..utils.naming import TOWER_FREEZE_KEYS
 from ..tfutils.common import get_tf_version_number
 from ..tfutils.tower import TowerContext
 from ..tfutils.collection import backup_collection, restore_collection
-from ..tfutils.gradproc import FilterNoneGrad, ScaleGradient
+from ..tfutils.gradproc import ScaleGradient
 from ..callbacks.graph import RunOp
 from .base import Trainer
@@ -101,16 +101,13 @@ class MultiGPUTrainerBase(Trainer):
         return ret

     @staticmethod
-    def check_none_grads(name, grads):
-        # grads: list of N grads
-        nones = list(set(grads))
-        if None in nones:
-            if len(nones) != 1:
-                raise RuntimeError("Gradient w.r.t {} is None in some but not all towers!".format(name))
-            else:
-                logger.warn("No Gradient w.r.t {}".format(name))
-                return False
-        return True
+    def _check_grad_list(grad_list):
+        """
+        Args:
+            grad_list: list of list of tuples, shape is Ngpu x Nvar x 2
+        """
+        nvars = [len(k) for k in grad_list]
+        assert len(set(nvars)) == 1, "Number of gradients from each tower is different! " + str(nvars)


 # Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
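Note: because `FilterNoneGrad` now runs identically inside every tower, each tower is expected to drop the same variables, so the per-tower lists should all have the same length. The new `_check_grad_list` only asserts that invariant instead of repairing mismatches variable by variable. A toy illustration of the check, with strings standing in for real (grad, var) tensors:

```python
# Hypothetical grad_list for 2 towers and 2 variables (Ngpu x Nvar x 2).
grad_list = [
    [('g0/W', 'W'), ('g0/b', 'b')],   # tower 0
    [('g1/W', 'W'), ('g1/b', 'b')],   # tower 1
]
nvars = [len(k) for k in grad_list]   # [2, 2]
assert len(set(nvars)) == 1, nvars    # [2, 1] would trip the new assertion
```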
@@ -175,8 +172,6 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
             v = grad_and_vars[0][1]
             all_grads = [g for (g, _) in grad_and_vars]
-            if not MultiGPUTrainerBase.check_none_grads(v.op.name, all_grads):
-                continue
             with tf.device(v.device):   # colocate summed grad with var
                 grad = tf.multiply(
                     tf.add_n(all_grads), 1.0 / nr_tower)
@@ -195,6 +190,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower, lambda: self._get_cost_and_grad()[1], devices)
+        MultiGPUTrainerBase._check_grad_list(grad_list)

         # debug tower performance (without update):
         # ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
@@ -243,8 +239,6 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         for grad_and_vars in zip(*tower_grads):
             v = grad_and_vars[0][1]
             grads = [g for g, _ in grad_and_vars]
-            if not MultiGPUTrainerBase.check_none_grads(v.op.name, grads):
-                continue
             summed = nccl.all_sum(grads)

             grads_for_a_var = []
@@ -322,8 +316,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower, lambda: self._get_cost_and_grad()[1], devices)
-        grad_list = [FilterNoneGrad().process(gv) for gv in grad_list]
+        MultiGPUTrainerBase._check_grad_list(grad_list)

         if self._scale_gradient and self.config.nr_tower > 1:
             # pretend to average the grads, in order to make async and
             # sync have consistent effective learning rate
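Note: after this commit the async trainer no longer filters gradients itself; filtering happens once per tower inside `SingleCostFeedfreeTrainer._get_cost_and_grad`, and every multi-GPU path only asserts that the towers agree. A simplified, hypothetical sketch of what each tower now effectively returns (not the actual trainer API):

```python
import tensorflow as tf

def tower_grads(cost, varlist):
    # Simplified stand-in for SingleCostFeedfreeTrainer._get_cost_and_grad:
    # compute gradients, then drop (None, var) pairs as FilterNoneGrad does.
    grads = tf.gradients(cost, varlist)
    return [(g, v) for g, v in zip(grads, varlist) if g is not None]

# grad_list = [tower_grads(cost_i, varlist) for each tower]
# MultiGPUTrainerBase._check_grad_list(grad_list)  # equal length per tower
# ...then tf.add_n / nccl.all_sum / apply_gradients, depending on the trainer.
```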