Commit 3f743301 authored by Yuxin Wu

async training.

parent 08821b55
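In the async mode added by this commit, each GPU tower beyond the first gets its own apply_gradients op, driven by a pausable LoopThread, while the main loop keeps stepping tower 0; gradients are scaled by 1/nr_tower so the per-tower updates stay roughly consistent with the synchronous average. A minimal sketch of the driving logic, using hypothetical names (sess, tower_train_ops, train_op) rather than the library's API:

    # one LoopThread per extra tower; each simply re-runs that tower's update op
    workers = []
    for op in tower_train_ops[1:]:
        th = LoopThread(lambda op=op: sess.run([op]))  # bind op now, not at call time
        th.pause()            # stay idle until the first training step
        th.start()
        workers.append(th)
    async_running = False

    def run_step():
        global async_running
        if not async_running:         # (re)start the workers after an epoch boundary
            async_running = True
            for th in workers:
                th.resume()
        sess.run([train_op])          # the main thread keeps updating tower 0

    def trigger_epoch():
        global async_running
        async_running = False
        for th in workers:
            th.pause()                # quiesce the workers while summaries read the queue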
@@ -48,6 +48,7 @@ class Trainer(object):
@abstractmethod
def _trigger_epoch(self):
""" This is called right after all steps in an epoch are finished"""
pass
def _init_summary(self):
@@ -94,7 +95,7 @@ class Trainer(object):
if self.coord.should_stop():
return
self.run_step()
callbacks.trigger_step()
#callbacks.trigger_step() # not useful?
self.global_step += 1
self.trigger_epoch()
except (KeyboardInterrupt, Exception):
@@ -11,6 +11,7 @@ from six.moves import zip
from .base import Trainer
from ..dataflow.common import RepeatedData
from ..utils import *
from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average
from ..tfutils import *
@@ -79,13 +80,14 @@ class EnqueueThread(threading.Thread):
finally:
logger.info("Enqueue Thread Exited.")
class QueueInputTrainer(Trainer):
"""
Trainer which builds a FIFO queue for input.
Supports multi-GPU training.
"""
def __init__(self, config, input_queue=None):
def __init__(self, config, input_queue=None, async=False):
"""
:param config: a `TrainConfig` instance
:param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
:param async: whether to train the extra towers asynchronously. Requires nr_tower > 1.
@@ -98,6 +100,9 @@ class QueueInputTrainer(Trainer):
100, [x.dtype for x in self.input_vars], name='input_queue')
else:
self.input_queue = input_queue
self.async = async
if self.async:
assert self.config.nr_tower > 1
@staticmethod
def _average_grads(tower_grads):
@@ -122,14 +127,15 @@ class QueueInputTrainer(Trainer):
qv.set_shape(v.get_shape())
return ret
def _single_tower_grad_cost(self):
def _single_tower_grad(self):
""" Get grad and cost for single-tower case"""
model_inputs = self._get_model_inputs()
cost_var = self.model.get_cost(model_inputs, is_training=True)
grads = self.config.optimizer.compute_gradients(cost_var)
return (grads, cost_var)
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
return grads
def _multi_tower_grad_cost(self):
def _multi_tower_grads(self):
logger.info("Training a model of {} tower".format(self.config.nr_tower))
# to avoid repeated summary from each device
@@ -140,6 +146,7 @@ class QueueInputTrainer(Trainer):
for i in range(self.config.nr_tower):
with tf.device('/gpu:{}'.format(i)), \
tf.name_scope('tower{}'.format(i)) as scope:
logger.info("Building graph for tower {}...".format(i))
model_inputs = self._get_model_inputs() # each tower dequeues from the input queue
cost_var = self.model.get_cost(model_inputs, is_training=True) # build tower
@@ -148,30 +155,49 @@ class QueueInputTrainer(Trainer):
self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))
if i == 0:
cost_var_t0 = cost_var
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
tf.get_variable_scope().reuse_variables()
for k in collect_dedup:
kept_summaries[k] = copy.copy(tf.get_collection(k))
logger.info("Graph built for tower {}.".format(i))
for k in collect_dedup:
del tf.get_collection_ref(k)[:]
tf.get_collection_ref(k).extend(kept_summaries[k])
grads = QueueInputTrainer._average_grads(grad_list)
return (grads, cost_var_t0)
return grad_list
def train(self):
enqueue_op = self.input_queue.enqueue(self.input_vars)
grads, cost_var = self._single_tower_grad_cost() \
if self.config.nr_tower == 0 else self._multi_tower_grad_cost()
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
avg_maintain_op = summary_moving_average()
grads = self.process_grads(grads)
if self.config.nr_tower > 1:
grad_list = self._multi_tower_grads()
if not self.async:
grads = QueueInputTrainer._average_grads(grad_list)
grads = self.process_grads(grads)
else:
grad_list = [self.process_grads(g) for g in grad_list]
# scale the grads by 1/nr_tower, so that these per-tower async updates
# stay consistent with the semantics of synchronous (averaged) training
def scale(grads):
return [(grad / self.config.nr_tower, var) for grad, var in grads]
grad_list = list(map(scale, grad_list))
grads = grad_list[0] # the first tower's grads drive the main train_op and global step
else:
grads = self._single_tower_grad()
grads = self.process_grads(grads)
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
avg_maintain_op)
summary_moving_average())
if self.async:
self.threads = []
for k in range(1, self.config.nr_tower):
train_op = self.config.optimizer.apply_gradients(grad_list[k])
f = lambda op=train_op: self.sess.run([op]) # bind op per tower; a bare closure would capture only the last train_op
th = LoopThread(f)
th.pause()
th.start()
self.threads.append(th)
self.async_running = False
self.init_session_and_coord()
# create a thread that keeps filling the queue
@@ -183,14 +209,25 @@ class QueueInputTrainer(Trainer):
self.input_th.start()
def run_step(self):
if self.async:
if not self.async_running:
self.async_running = True
for th in self.threads: # resume all threads
th.resume()
self.sess.run([self.train_op]) # faster, since train_op returns nothing to fetch
def _trigger_epoch(self):
# note that running summary_op will consume a datapoint from the queue
if self.async:
self.async_running = False
for th in self.threads:
th.pause()
if self.summary_op is not None:
summary_str = self.summary_op.eval()
self._process_summary(summary_str)
def start_train(config):
tr = QueueInputTrainer(config)
tr.train()
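A minimal usage sketch of the new flag, assuming a `TrainConfig` built elsewhere (dataset, model, optimizer, callbacks); the helper name below is illustrative, and async mode asserts nr_tower > 1:

    config = get_config()      # hypothetical helper that returns a TrainConfig
    config.nr_tower = 2        # async training requires more than one tower
    QueueInputTrainer(config, async=True).train()

Note that `async` later became a reserved word in Python 3.7+, so this spelling of the keyword argument only parses on the Python versions contemporary with this commit.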
@@ -9,7 +9,7 @@ import atexit
import bisect
import weakref
__all__ = ['StoppableThread', 'ensure_proc_terminate',
__all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE']
class StoppableThread(threading.Thread):
@@ -23,6 +23,29 @@ class StoppableThread(threading.Thread):
def stopped(self):
return self._stop.isSet()
class LoopThread(threading.Thread):
""" A pausable thread that simply runs a loop"""
def __init__(self, func):
"""
:param func: the function to run
"""
super(LoopThread, self).__init__()
self.func = func
self.lock = threading.Lock()
self.daemon = True
def run(self):
while True:
self.lock.acquire() # blocks here while pause() holds the lock
self.lock.release()
self.func()
def pause(self):
self.lock.acquire()
def resume(self):
self.lock.release()
class DIE(object):
""" A placeholder class indicating end of queue """