Commit fd21c3b1 authored by Yuxin Wu's avatar Yuxin Wu

fix async training late-binding bug

parent b6a775f4
......@@ -21,7 +21,8 @@ import os
sys.path.insert(0, os.path.abspath('../'))
import mock
MOCK_MODULES = ['numpy', 'scipy', 'tensorflow', 'scipy.misc', 'h5py', 'nltk', 'cv2']
MOCK_MODULES = ['numpy', 'scipy', 'tensorflow', 'scipy.misc', 'h5py', 'nltk',
'cv2', 'scipy.io']
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock()
......
......@@ -6,6 +6,7 @@ import tensorflow as tf
import threading
import copy
import re
import functools
from six.moves import zip
from .base import Trainer
......@@ -175,7 +176,7 @@ class QueueInputTrainer(Trainer):
else:
grad_list = [self.process_grads(g) for g in grad_list]
# pretend to average the grads, in order to make async and
# sync have consistent semantics
# sync have consistent effective learning rate
def scale(grads):
return [(grad / self.config.nr_tower, var) for grad, var in grads]
grad_list = map(scale, grad_list)
......@@ -192,7 +193,7 @@ class QueueInputTrainer(Trainer):
self.threads = []
for k in range(1, self.config.nr_tower):
train_op = self.config.optimizer.apply_gradients(grad_list[k])
f = lambda : self.sess.run([train_op])
f = lambda op=train_op: self.sess.run([op]) # avoid late-binding
th = LoopThread(f)
th.pause()
th.start()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment