Commit 111cb09b authored by Yuxin Wu's avatar Yuxin Wu

Merge branch 'better-a3c'

parents 8e4c695c c12cf88b
......@@ -17,7 +17,8 @@ from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average
from ..tfutils import *
__all__ = ['SimpleTrainer', 'QueueInputTrainer', 'start_train']
__all__ = ['SimpleTrainer', 'QueueInputTrainer',
'AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
class SimpleTrainer(Trainer):
def run_step(self):
......@@ -269,7 +270,9 @@ class QueueInputTrainer(Trainer):
return [self.get_predict_func(input_names, output_names, k)
for k in range(n)]
def start_train(config):
tr = QueueInputTrainer(config)
tr.train()
def AsyncMultiGPUTrainer(config):
return QueueInputTrainer(config, async=True)
def SyncMultiGPUTrainer(config):
return QueueInputTrainer(config)
......@@ -55,28 +55,33 @@ class StoppableThread(threading.Thread):
except queue.Empty:
pass
class LoopThread(threading.Thread):
class LoopThread(StoppableThread):
""" A pausable thread that simply runs a loop"""
def __init__(self, func):
def __init__(self, func, pausable=True):
"""
:param func: the function to run
"""
super(LoopThread, self).__init__()
self.func = func
self.lock = threading.Lock()
self._func = func
self._pausable = pausable
if pausable:
self._lock = threading.Lock()
self.daemon = True
def run(self):
while True:
self.lock.acquire()
self.lock.release()
self.func()
while not self.stopped():
if self._pausable:
self._lock.acquire()
self._lock.release()
self._func()
def pause(self):
self.lock.acquire()
assert self._pausable
self._lock.acquire()
def resume(self):
self.lock.release()
assert self._pausable
self._lock.release()
class DIE(object):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment