Commit 3898d354 authored by Yuxin Wu

add some small assertions

parent b48cd060
@@ -67,7 +67,7 @@ class GlobalNormClip(GradientProcessor):
         Args:
             global_norm(float): the threshold to clip with.
         """
-        self._norm = global_norm
+        self._norm = float(global_norm)

     def _process(self, grads):
         g = [k[0] for k in grads]
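The hunk above only casts the threshold; the clipping itself happens in _process. As a minimal sketch (a paraphrase, not the verbatim tensorpack source), the processor rescales all gradients together with tf.clip_by_global_norm so their joint L2 norm stays at or below the threshold:

import tensorflow as tf

def clip_grads_by_global_norm(grads, norm):
    # grads is a list of (gradient, variable) pairs, as TF optimizers use.
    g = [k[0] for k in grads]   # gradients
    v = [k[1] for k in grads]   # matching variables
    # Every gradient is scaled by min(1, norm / global_norm(g)).
    g, _ = tf.clip_by_global_norm(g, norm)
    return list(zip(g, v))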
@@ -176,6 +176,7 @@ class ScaleGradient(MapGradient):
         if not isinstance(multipliers, list):
             multipliers = [multipliers]
         self.multipliers = multipliers
+        assert log in [True, False], log
         self._log = log
         super(ScaleGradient, self).__init__(self._mapper)
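For context, a hedged usage sketch of ScaleGradient, assuming the (regex, multiplier) interface this class takes in tensorpack; the new assert simply rejects a non-boolean log flag at construction time instead of failing later:

from tensorpack.tfutils.gradproc import ScaleGradient

# Double the gradient of every bias; freeze everything under 'conv0'.
proc = ScaleGradient([('.*/b', 2.0), ('conv0/.*', 0.0)], log=True)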
...
@@ -16,6 +16,7 @@ class ProxyOptimizer(tf.train.Optimizer):
     A transparent proxy which delegates all methods of :class:`tf.train.Optimizer`
     """
     def __init__(self, opt, name='ProxyOptimizer'):
+        assert isinstance(opt, tf.train.Optimizer), opt
         super(ProxyOptimizer, self).__init__(False, name)
         self._opt = opt
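A sketch of the delegation pattern the docstring describes, under the assumption that only the standard tf.train.Optimizer entry points need forwarding. The added assert turns a wrong opt argument into an immediate, readable failure instead of a late AttributeError deep inside training:

import tensorflow as tf

class ProxyOptimizer(tf.train.Optimizer):
    """Forward every optimizer method to the wrapped instance."""
    def __init__(self, opt, name='ProxyOptimizer'):
        assert isinstance(opt, tf.train.Optimizer), opt
        super(ProxyOptimizer, self).__init__(False, name)
        self._opt = opt

    def compute_gradients(self, *args, **kwargs):
        return self._opt.compute_gradients(*args, **kwargs)

    def apply_gradients(self, *args, **kwargs):
        return self._opt.apply_gradients(*args, **kwargs)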
@@ -44,6 +45,7 @@ def apply_grad_processors(opt, gradprocs):
         a :class:`tf.train.Optimizer` instance which runs the gradient
         processors before updating the variables.
     """
+    assert isinstance(gradprocs, (list, tuple)), gradprocs

     class _ApplyGradientProcessor(ProxyOptimizer):
         def __init__(self, opt, gradprocs):
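A hedged usage sketch tying the two files together (names as in the hunks above; the clipping threshold is illustrative only):

import tensorflow as tf
from tensorpack.tfutils.gradproc import GlobalNormClip
from tensorpack.tfutils.optimizer import apply_grad_processors

opt = tf.train.AdamOptimizer(1e-3)
# Returns a tf.train.Optimizer whose apply_gradients() runs the
# processors first. Passing a bare processor instead of a list
# now trips the new assert instead of failing obscurely later.
opt = apply_grad_processors(opt, [GlobalNormClip(5.0)])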
...
@@ -82,7 +82,7 @@ class MultiGPUTrainerBase(Trainer):
             if idx == t:
                 logger.info("Building graph for training tower {}...".format(idx))
             else:
-                logger.info("Building graph for training tower {} on device {}...".format(idx, t))
+                logger.info("Building graph for training tower {} on device {}...".format(idx, device))
             ret.append(func())
@@ -264,7 +264,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
             run_before=True, run_as_trigger=True))

-    # Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
+    # Adopt from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
     @staticmethod
     def get_post_init_ops():
         # Copy initialized values for variables on GPU 0 to other GPUs.
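The helper itself is only referenced here, so the following is a sketch of the idea under an assumed tower naming scheme ('tower0/...' for GPU 0 replicas; the real variable names may differ): after initialization, assign every non-tower-0 replica the value of its tower-0 counterpart so all GPUs start from identical weights.

import tensorflow as tf

def get_post_init_ops():
    all_vars = tf.global_variables()
    by_name = {v.name: v for v in all_vars}
    ops = []
    for v in all_vars:
        parts = v.name.split('/')
        if parts[0].startswith('tower') and parts[0] != 'tower0':
            # e.g. 'tower1/conv0/W:0' copies from 'tower0/conv0/W:0'
            src = by_name.get('/'.join(['tower0'] + parts[1:]))
            if src is not None:
                ops.append(v.assign(src.read_value()))
    return tf.group(*ops, name='sync_variables_from_tower0')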
...