Commit 3f743301 authored by Yuxin Wu

async training.

parent 08821b55
@@ -48,6 +48,7 @@ class Trainer(object):
     @abstractmethod
     def _trigger_epoch(self):
+        """ This is called right after all steps in an epoch are finished"""
         pass

     def _init_summary(self):
@@ -94,7 +95,7 @@ class Trainer(object):
                     if self.coord.should_stop():
                         return
                     self.run_step()
-                    callbacks.trigger_step()
+                    #callbacks.trigger_step()   # not useful?
                     self.global_step += 1
                 self.trigger_epoch()
             except (KeyboardInterrupt, Exception):
...
@@ -11,6 +11,7 @@ from six.moves import zip
 from .base import Trainer
 from ..dataflow.common import RepeatedData
 from ..utils import *
+from ..utils.concurrency import LoopThread
 from ..tfutils.summary import summary_moving_average
 from ..tfutils import *
@@ -79,13 +80,14 @@ class EnqueueThread(threading.Thread):
         finally:
             logger.info("Enqueue Thread Exited.")

 class QueueInputTrainer(Trainer):
     """
     Trainer which builds a FIFO queue for input.
     Support multi GPU.
     """

-    def __init__(self, config, input_queue=None):
+    def __init__(self, config, input_queue=None, async=False):
         """
         :param config: a `TrainConfig` instance
         :param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
@@ -98,6 +100,9 @@ class QueueInputTrainer(Trainer):
                 100, [x.dtype for x in self.input_vars], name='input_queue')
         else:
             self.input_queue = input_queue
+        self.async = async
+        if self.async:
+            assert self.config.nr_tower > 1

     @staticmethod
     def _average_grads(tower_grads):
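The body of `_average_grads` is collapsed in this diff. For reference, a minimal sketch of the pattern such a helper typically implements (averaging each variable's gradient across towers) is shown below; it assumes every entry of `tower_grads` is a list of `(gradient, variable)` pairs in the same variable order, and it is not necessarily the exact tensorpack implementation:

import tensorflow as tf

def average_grads(tower_grads):
    """Average (gradient, variable) pairs across towers.

    tower_grads: one list of (grad, var) pairs per tower, all in the same
    variable order, as returned by optimizer.compute_gradients().
    """
    averaged = []
    for grads_and_vars in zip(*tower_grads):
        # the variables are shared across towers, so take it from the first one
        var = grads_and_vars[0][1]
        grads = [g for g, _ in grads_and_vars]
        avg = tf.add_n(grads) / float(len(grads))
        averaged.append((avg, var))
    return averaged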
@@ -122,14 +127,15 @@ class QueueInputTrainer(Trainer):
             qv.set_shape(v.get_shape())
         return ret

-    def _single_tower_grad_cost(self):
+    def _single_tower_grad(self):
         """ Get grad and cost for single-tower case"""
         model_inputs = self._get_model_inputs()
         cost_var = self.model.get_cost(model_inputs, is_training=True)
         grads = self.config.optimizer.compute_gradients(cost_var)
-        return (grads, cost_var)
+        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
+        return grads

-    def _multi_tower_grad_cost(self):
+    def _multi_tower_grads(self):
         logger.info("Training a model of {} tower".format(self.config.nr_tower))

         # to avoid repeated summary from each device
@@ -140,6 +146,7 @@ class QueueInputTrainer(Trainer):
         for i in range(self.config.nr_tower):
             with tf.device('/gpu:{}'.format(i)), \
                     tf.name_scope('tower{}'.format(i)) as scope:
+                logger.info("Building graph for tower {}...".format(i))
                 model_inputs = self._get_model_inputs()    # each tower dequeue from input queue
                 cost_var = self.model.get_cost(model_inputs, is_training=True) # build tower
@@ -148,30 +155,49 @@ class QueueInputTrainer(Trainer):
                     self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))
                 if i == 0:
-                    cost_var_t0 = cost_var
+                    tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
                     tf.get_variable_scope().reuse_variables()
                     for k in collect_dedup:
                         kept_summaries[k] = copy.copy(tf.get_collection(k))
+            logger.info("Graph built for tower {}.".format(i))
         for k in collect_dedup:
             del tf.get_collection_ref(k)[:]
             tf.get_collection_ref(k).extend(kept_summaries[k])
-        grads = QueueInputTrainer._average_grads(grad_list)
-        return (grads, cost_var_t0)
+        return grad_list

     def train(self):
         enqueue_op = self.input_queue.enqueue(self.input_vars)
-        grads, cost_var = self._single_tower_grad_cost() \
-                if self.config.nr_tower == 0 else self._multi_tower_grad_cost()
-        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
-        avg_maintain_op = summary_moving_average()
-
-        grads = self.process_grads(grads)
+        if self.config.nr_tower > 1:
+            grad_list = self._multi_tower_grads()
+            if not self.async:
+                grads = QueueInputTrainer._average_grads(grad_list)
+                grads = self.process_grads(grads)
+            else:
+                grad_list = [self.process_grads(g) for g in grad_list]
+                # pretend to average the grads, in order to make async and
+                # sync have consistent semantics
+                def scale(grads):
+                    return [(grad / self.config.nr_tower, var) for grad, var in grads]
+                grad_list = map(scale, grad_list)
+                grads = grad_list[0]  # use grad from the first tower for routinely stuff
+        else:
+            grads = self._single_tower_grad()
+            grads = self.process_grads(grads)

         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
-            avg_maintain_op)
+            summary_moving_average())
+
+        if self.async:
+            self.threads = []
+            for k in range(1, self.config.nr_tower):
+                train_op = self.config.optimizer.apply_gradients(grad_list[k])
+                f = lambda : self.sess.run([train_op])
+                th = LoopThread(f)
+                th.pause()
+                th.start()
+                self.threads.append(th)
+            self.async_running = False

         self.init_session_and_coord()
         # create a thread that keeps filling the queue
@@ -183,14 +209,25 @@ class QueueInputTrainer(Trainer):
         self.input_th.start()

     def run_step(self):
+        if self.async:
+            if not self.async_running:
+                self.async_running = True
+                for th in self.threads: # resume all threads
+                    th.resume()
         self.sess.run([self.train_op])    # faster since train_op return None

     def _trigger_epoch(self):
         # note that summary_op will take a data from the queue
+        if self.async:
+            self.async_running = False
+            for th in self.threads:
+                th.pause()
         if self.summary_op is not None:
             summary_str = self.summary_op.eval()
             self._process_summary(summary_str)

 def start_train(config):
     tr = QueueInputTrainer(config)
     tr.train()
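For reference, a hypothetical usage sketch of the new flag (the `TrainConfig` construction is elided and assumed to provide `nr_tower > 1`, as required by the assert in `__init__`; the import path is also an assumption). Note that the argument name `async` was valid at the time of this commit but later became a reserved keyword in Python 3.7:

from tensorpack.train.trainer import QueueInputTrainer

def start_async_train(config):
    # config: an existing TrainConfig, assumed to have config.nr_tower > 1
    trainer = QueueInputTrainer(config, async=True)
    trainer.train()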
...@@ -9,7 +9,7 @@ import atexit ...@@ -9,7 +9,7 @@ import atexit
import bisect import bisect
import weakref import weakref
__all__ = ['StoppableThread', 'ensure_proc_terminate', __all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE'] 'OrderedResultGatherProc', 'OrderedContainer', 'DIE']
class StoppableThread(threading.Thread): class StoppableThread(threading.Thread):
...@@ -23,6 +23,29 @@ class StoppableThread(threading.Thread): ...@@ -23,6 +23,29 @@ class StoppableThread(threading.Thread):
def stopped(self): def stopped(self):
return self._stop.isSet() return self._stop.isSet()
class LoopThread(threading.Thread):
""" A pausable thread that simply runs a loop"""
def __init__(self, func):
"""
:param func: the function to run
"""
super(LoopThread, self).__init__()
self.func = func
self.lock = threading.Lock()
self.daemon = True
def run(self):
while True:
self.lock.acquire()
self.lock.release()
self.func()
def pause(self):
self.lock.acquire()
def resume(self):
self.lock.release()
class DIE(object): class DIE(object):
""" A placeholder class indicating end of queue """ """ A placeholder class indicating end of queue """
......
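`LoopThread` implements pause/resume with a single `threading.Lock`: `run` re-acquires and releases the lock before every iteration, so a thread paused by the main thread blocks at `acquire` until `resume` releases the lock again. A small self-contained example of that behaviour, with a toy counter standing in for the per-tower training function (the class body mirrors the one added in this commit so the snippet runs on its own):

import threading
import time

class LoopThread(threading.Thread):
    """ A pausable thread that simply runs a loop (mirrors this commit). """
    def __init__(self, func):
        super(LoopThread, self).__init__()
        self.func = func
        self.lock = threading.Lock()
        self.daemon = True

    def run(self):
        while True:
            self.lock.acquire()   # blocks here while the main thread holds the lock
            self.lock.release()
            self.func()

    def pause(self):
        self.lock.acquire()

    def resume(self):
        self.lock.release()

counter = {'n': 0}
def step():
    counter['n'] += 1

th = LoopThread(step)
th.pause()                  # start paused, as train() does before the first run_step
th.start()
time.sleep(0.1)
assert counter['n'] == 0    # no iterations while paused

th.resume()                 # what run_step() does for the worker threads
time.sleep(0.1)
th.pause()                  # what _trigger_epoch() does before evaluating summaries
print(counter['n'])         # some iterations ran while resumed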