Commit 11c46a71 authored by Yuxin Wu's avatar Yuxin Wu

subproc_call & fix multiprocess sigint problem

parent b852d652
...@@ -125,7 +125,7 @@ class ExpReplay(DataFlow, Callback): ...@@ -125,7 +125,7 @@ class ExpReplay(DataFlow, Callback):
end_exploration=0.1, end_exploration=0.1,
exploration_epoch_anneal=0.002, exploration_epoch_anneal=0.002,
reward_clip=None, reward_clip=None,
new_experience_per_step=1, update_frequency=1,
history_len=1 history_len=1
): ):
""" """
...@@ -196,7 +196,7 @@ class ExpReplay(DataFlow, Callback): ...@@ -196,7 +196,7 @@ class ExpReplay(DataFlow, Callback):
#view_state(exp[0], exp[1]) #view_state(exp[0], exp[1])
yield self._process_batch(batch_exp) yield self._process_batch(batch_exp)
for _ in range(self.new_experience_per_step): for _ in range(self.update_frequency):
self._populate_exp() self._populate_exp()
def sample_one(self): def sample_one(self):
......
...@@ -2,12 +2,13 @@ ...@@ -2,12 +2,13 @@
# File: base.py # File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import signal
import re
from six.moves import range from six.moves import range
import tqdm import tqdm
import re
import tensorflow as tf
from .config import TrainConfig from .config import TrainConfig
from ..utils import * from ..utils import *
from ..callbacks import StatHolder from ..callbacks import StatHolder
...@@ -135,8 +136,12 @@ class Trainer(object): ...@@ -135,8 +136,12 @@ class Trainer(object):
""" """
tf.train.start_queue_runners( tf.train.start_queue_runners(
sess=self.sess, coord=self.coord, daemon=True, start=True) sess=self.sess, coord=self.coord, daemon=True, start=True)
# avoid sigint get handled by other processes
sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
for k in self.extra_threads_procs: for k in self.extra_threads_procs:
k.start() k.start()
signal.signal(signal.SIGINT, sigint_handler)
def process_grads(self, grads): def process_grads(self, grads):
......
...@@ -8,6 +8,11 @@ import multiprocessing ...@@ -8,6 +8,11 @@ import multiprocessing
import atexit import atexit
import bisect import bisect
import weakref import weakref
import six
if six.PY2:
import subprocess32 as subprocess
else:
import subprocess
__all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate', __all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE'] 'OrderedResultGatherProc', 'OrderedContainer', 'DIE']
...@@ -69,6 +74,18 @@ def ensure_proc_terminate(proc): ...@@ -69,6 +74,18 @@ def ensure_proc_terminate(proc):
assert isinstance(proc, multiprocessing.Process) assert isinstance(proc, multiprocessing.Process)
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc)) atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
def subproc_call(cmd, timeout=None):
try:
output = subprocess.check_output(
cmd, stderr=subprocess.STDOUT,
shell=True, timeout=timeout)
return output
except subprocess.TimeoutExpired as e:
logger.warn("Timeout in evaluation!")
logger.warn(e.output)
except subprocess.CalledProcessError as e:
logger.warn("Evaluation script failed: {}".format(e.returncode))
logger.warn(e.output)
class OrderedContainer(object): class OrderedContainer(object):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment