Commit 12bf21bc authored by Yuxin Wu

asynctrainer global counter

parent 6ecaab67
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf import tensorflow as tf
import itertools, re
from six.moves import zip, range from six.moves import zip, range
from ..utils import * from ..utils import *
...@@ -104,30 +105,45 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer): ...@@ -104,30 +105,45 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
summary_moving_average(), name='train_op') summary_moving_average(), name='train_op')
describe_model() describe_model()
self._start_async_threads(grad_list)
with freeze_collection(self.SUMMARY_BACKUP_KEYS):
self._build_predict_tower()
self.main_loop()
def _start_async_threads(self, grad_list):
    """Create one paused training thread per extra tower.

    For every tower index from 1 up to ``self.config.nr_tower`` an
    ``apply_gradients`` op is built from ``grad_list[k]`` and wrapped in a
    ``LoopThread``. Index 0 is skipped here. The threads are created
    paused; ``run_step`` resumes them.

    Args:
        grad_list: indexable by tower; ``grad_list[k]`` is passed as-is to
            ``optimizer.apply_gradients`` (presumably a list of
            (gradient, variable) pairs -- confirm against the caller).
    """
    # prepare train_op for the rest of the towers
    # itertools.count is atomic w.r.t. python threads
    # (this relies on the CPython GIL, not on an explicit lock).
    self.async_step_counter = itertools.count()
    self.training_threads = []
    for k in range(1, self.config.nr_tower):
        train_op = self.config.optimizer.apply_gradients(grad_list[k])

        def f(op=train_op):  # default argument avoids late-binding of train_op
            self.sess.run([op])
            # Builtin next() works on Python 2.6+ and Python 3; the
            # iterator's .next() method exists only on Python 2.
            next(self.async_step_counter)

        th = LoopThread(f)
        # pause() is requested before start(), so the thread is expected
        # to do no work until it is explicitly resumed.
        th.pause()
        th.start()
        self.training_threads.append(th)
    self.async_running = False
with freeze_collection(self.SUMMARY_BACKUP_KEYS):
self._build_predict_tower()
self.main_loop()
def run_step(self):
    """Run one training step and count it in the async step counter.

    On the first call after a pause, all threads in
    ``self.training_threads`` are resumed before the step runs.
    """
    if not self.async_running:
        self.async_running = True
        for th in self.training_threads:  # resume all threads
            th.resume()
    # Count this step as well. Builtin next() is used because the
    # iterator's .next() method does not exist on Python 3.
    next(self.async_step_counter)
    super(AsyncMultiGPUTrainer, self).run_step()
def _trigger_epoch(self):
    """Pause the async threads, record the async step count, then run
    the parent class's epoch trigger.

    Writing the ``async_global_step`` summary is best-effort: a failure
    there is logged and never prevents the parent trigger from running.
    """
    self.async_running = False
    for th in self.training_threads:
        th.pause()
    try:
        # itertools.count exposes no public way to read its value without
        # advancing it, so the number is parsed from its repr, e.g.
        # 'count(42)'.
        async_step_total_cnt = int(re.findall(
            r'[0-9]+', repr(self.async_step_counter))[0])
        self.write_scalar_summary(
            'async_global_step', async_step_total_cnt)
    except Exception:
        # Still best-effort, but unlike a bare ``except:`` this lets
        # KeyboardInterrupt / SystemExit propagate, and the failure is
        # logged instead of being dropped silently.
        import logging
        logging.getLogger(__name__).warning(
            "Failed to write 'async_global_step' summary", exc_info=True)
    super(AsyncMultiGPUTrainer, self)._trigger_epoch()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment