Commit a949bfa6 authored by Yuxin Wu's avatar Yuxin Wu

speedup lr_mult=0 by skipping the gradient computation.

parent b5a238a7
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: cifar10_convnet.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
......
......@@ -49,8 +49,7 @@ class PrefetchData(ProxyDataFlow):
x.start()
def get_data(self):
    """Yield datapoints pulled from the prefetch queue.

    Yields exactly ``self._size`` datapoints, each taken with a blocking
    ``get()`` from ``self.queue`` (filled by the prefetch worker
    processes started elsewhere in this class).
    """
    # NOTE: the diff residue showed an unused `tot_cnt = 0` counter and a
    # dead `for _ in range(tot_cnt)` loop (iterates zero times); the
    # post-commit loop bound is self._size.
    for _ in range(self._size):
        dp = self.queue.get()
        yield dp
......
......@@ -61,14 +61,18 @@ class ScaleGradient(GradientProcessor):
self.multipliers = multipliers
def _process(self, grads):
# TODO use None for zero can speed up (or not)?
ret = []
for grad, var in grads:
varname = var.op.name
for regex, val in self.multipliers:
if re.search(regex, varname):
# always match against the whole name
if not regex.endswith('$'):
regex = regex + '$'
if re.match(regex, varname):
logger.info("Apply lr multiplier {} for {}".format(val, varname))
ret.append((grad * val, var))
if val != 0: # skip zero to speed up
ret.append((grad * val, var))
break
else:
ret.append((grad, var))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment