Commit fd21c3b1 authored by Yuxin Wu's avatar Yuxin Wu

fix async training late-binding bug

parent b6a775f4
...@@ -21,7 +21,8 @@ import os ...@@ -21,7 +21,8 @@ import os
sys.path.insert(0, os.path.abspath('../')) sys.path.insert(0, os.path.abspath('../'))
import mock import mock
MOCK_MODULES = ['numpy', 'scipy', 'tensorflow', 'scipy.misc', 'h5py', 'nltk', 'cv2'] MOCK_MODULES = ['numpy', 'scipy', 'tensorflow', 'scipy.misc', 'h5py', 'nltk',
'cv2', 'scipy.io']
for mod_name in MOCK_MODULES: for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock() sys.modules[mod_name] = mock.Mock()
......
...@@ -6,6 +6,7 @@ import tensorflow as tf ...@@ -6,6 +6,7 @@ import tensorflow as tf
import threading import threading
import copy import copy
import re import re
import functools
from six.moves import zip from six.moves import zip
from .base import Trainer from .base import Trainer
...@@ -175,7 +176,7 @@ class QueueInputTrainer(Trainer): ...@@ -175,7 +176,7 @@ class QueueInputTrainer(Trainer):
else: else:
grad_list = [self.process_grads(g) for g in grad_list] grad_list = [self.process_grads(g) for g in grad_list]
# pretend to average the grads, in order to make async and # pretend to average the grads, in order to make async and
# sync have consistent semantics # sync have consistent effective learning rate
def scale(grads): def scale(grads):
return [(grad / self.config.nr_tower, var) for grad, var in grads] return [(grad / self.config.nr_tower, var) for grad, var in grads]
grad_list = map(scale, grad_list) grad_list = map(scale, grad_list)
...@@ -192,7 +193,7 @@ class QueueInputTrainer(Trainer): ...@@ -192,7 +193,7 @@ class QueueInputTrainer(Trainer):
self.threads = [] self.threads = []
for k in range(1, self.config.nr_tower): for k in range(1, self.config.nr_tower):
train_op = self.config.optimizer.apply_gradients(grad_list[k]) train_op = self.config.optimizer.apply_gradients(grad_list[k])
f = lambda : self.sess.run([train_op]) f = lambda op=train_op: self.sess.run([op]) # avoid late-binding
th = LoopThread(f) th = LoopThread(f)
th.pause() th.pause()
th.start() th.start()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment