Commit a949bfa6 authored by Yuxin Wu

Speed up lr_mult=0 by skipping the gradient computation.

parent b5a238a7
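For context, `ScaleGradient` (modified below) runs over the `(gradient, variable)` list before the optimizer applies it, and a zero multiplier is how a user freezes part of the network. A hypothetical configuration using the class as shown in this diff (the variable-name patterns are made up; the surrounding trainer code is not part of the commit):

```python
# Hypothetical usage of ScaleGradient from this diff: scale conv0's
# learning rate and freeze fc1. Patterns are matched against the whole
# variable name, per the change below.
grad_processor = ScaleGradient([
    ('conv0/.*', 0.1),  # train conv0 at a tenth of the base lr
    ('fc1/.*', 0),      # lr_mult=0: these gradients are now dropped entirely
])
```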
 #!/usr/bin/env python
 # -*- coding: UTF-8 -*-
-# File: cifar10_convnet.py
+# File: cifar10-convnet.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf
...
@@ -49,8 +49,7 @@ class PrefetchData(ProxyDataFlow):
             x.start()

     def get_data(self):
-        tot_cnt = 0
-        for _ in range(tot_cnt):
+        for _ in range(self._size):
             dp = self.queue.get()
             yield dp
...
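In the old code `tot_cnt` stayed at zero, so `get_data()` never yielded anything; the fix drains exactly `self._size` datapoints from the queue that the worker processes fill. For reference, a minimal standalone sketch of the same producer/consumer pattern (`produce` and `N_ITEMS` are illustrative stand-ins, not names from this commit):

```python
import multiprocessing

N_ITEMS = 100  # illustrative stand-in for self._size


def produce(queue):
    # Worker process: keep pushing datapoints into the shared queue.
    for i in range(N_ITEMS):
        queue.put([i])


if __name__ == '__main__':
    queue = multiprocessing.Queue(maxsize=10)
    worker = multiprocessing.Process(target=produce, args=(queue,))
    worker.start()

    # Consumer: drain a fixed number of datapoints, as the fixed
    # get_data() does with range(self._size).
    for _ in range(N_ITEMS):
        dp = queue.get()
    worker.join()
```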
@@ -61,14 +61,18 @@ class ScaleGradient(GradientProcessor):
         self.multipliers = multipliers

     def _process(self, grads):
+        # TODO use None for zero can speed up (or not)?
         ret = []
         for grad, var in grads:
             varname = var.op.name
             for regex, val in self.multipliers:
-                if re.search(regex, varname):
+                # always match against the whole name
+                if not regex.endswith('$'):
+                    regex = regex + '$'
+                if re.match(regex, varname):
                     logger.info("Apply lr multiplier {} for {}".format(val, varname))
-                    ret.append((grad * val, var))
+                    if val != 0:   # skip zero to speed up
+                        ret.append((grad * val, var))
                     break
             else:
                 ret.append((grad, var))
...
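Two behaviors change in this hunk: the pattern is now anchored so it must describe the entire variable name (with the old `re.search`, an unanchored pattern like `fc1` would also hit a variable named `fc10/W`), and a multiplier of 0 now drops the `(grad, var)` pair instead of appending a zero-scaled gradient, which is the speedup the commit message refers to. A small self-contained demonstration of the new matching rule (the variable names and multipliers are made up for illustration):

```python
import re

multipliers = [('conv0/.*', 0.1), ('fc1/.*', 0)]  # hypothetical config
varnames = ['conv0/W', 'fc1/W', 'fc10/W']

for varname in varnames:
    for regex, val in multipliers:
        # Same anchoring rule as the commit: match the whole name.
        if not regex.endswith('$'):
            regex = regex + '$'
        if re.match(regex, varname):
            if val != 0:
                print('{}: scale gradient by {}'.format(varname, val))
            else:
                print('{}: dropped (lr_mult=0), no update applied'.format(varname))
            break
    else:
        print('{}: unchanged'.format(varname))
```

Here `conv0/W` is scaled by 0.1, `fc1/W` is dropped, and `fc10/W` does not match `fc1/.*$`, so it falls through to the `else` branch unchanged.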