Commit c12cf88b authored by Yuxin Wu's avatar Yuxin Wu

trainer & concurrency

parent 8e4c695c
...@@ -17,7 +17,8 @@ from ..utils.concurrency import LoopThread ...@@ -17,7 +17,8 @@ from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average from ..tfutils.summary import summary_moving_average
from ..tfutils import * from ..tfutils import *
__all__ = ['SimpleTrainer', 'QueueInputTrainer', 'start_train'] __all__ = ['SimpleTrainer', 'QueueInputTrainer',
'AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
class SimpleTrainer(Trainer): class SimpleTrainer(Trainer):
def run_step(self): def run_step(self):
...@@ -269,7 +270,9 @@ class QueueInputTrainer(Trainer): ...@@ -269,7 +270,9 @@ class QueueInputTrainer(Trainer):
return [self.get_predict_func(input_names, output_names, k) return [self.get_predict_func(input_names, output_names, k)
for k in range(n)] for k in range(n)]
def start_train(config): def AsyncMultiGPUTrainer(config):
tr = QueueInputTrainer(config) return QueueInputTrainer(config, async=True)
tr.train()
def SyncMultiGPUTrainer(config):
return QueueInputTrainer(config)
...@@ -55,28 +55,33 @@ class StoppableThread(threading.Thread): ...@@ -55,28 +55,33 @@ class StoppableThread(threading.Thread):
except queue.Empty: except queue.Empty:
pass pass
class LoopThread(threading.Thread): class LoopThread(StoppableThread):
""" A pausable thread that simply runs a loop""" """ A pausable thread that simply runs a loop"""
def __init__(self, func): def __init__(self, func, pausable=True):
""" """
:param func: the function to run :param func: the function to run
""" """
super(LoopThread, self).__init__() super(LoopThread, self).__init__()
self.func = func self._func = func
self.lock = threading.Lock() self._pausable = pausable
if pausable:
self._lock = threading.Lock()
self.daemon = True self.daemon = True
def run(self): def run(self):
while True: while not self.stopped():
self.lock.acquire() if self._pausable:
self.lock.release() self._lock.acquire()
self.func() self._lock.release()
self._func()
def pause(self): def pause(self):
self.lock.acquire() assert self._pausable
self._lock.acquire()
def resume(self): def resume(self):
self.lock.release() assert self._pausable
self._lock.release()
class DIE(object): class DIE(object):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment