Commit 11c46a71 authored by Yuxin Wu's avatar Yuxin Wu

subproc_call & fix multiprocess sigint problem

parent b852d652
......@@ -125,7 +125,7 @@ class ExpReplay(DataFlow, Callback):
end_exploration=0.1,
exploration_epoch_anneal=0.002,
reward_clip=None,
new_experience_per_step=1,
update_frequency=1,
history_len=1
):
"""
......@@ -196,7 +196,7 @@ class ExpReplay(DataFlow, Callback):
#view_state(exp[0], exp[1])
yield self._process_batch(batch_exp)
for _ in range(self.new_experience_per_step):
for _ in range(self.update_frequency):
self._populate_exp()
def sample_one(self):
......
......@@ -2,12 +2,13 @@
# File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from abc import ABCMeta, abstractmethod
import signal
import re
from six.moves import range
import tqdm
import re
import tensorflow as tf
from .config import TrainConfig
from ..utils import *
from ..callbacks import StatHolder
......@@ -135,8 +136,12 @@ class Trainer(object):
"""
tf.train.start_queue_runners(
sess=self.sess, coord=self.coord, daemon=True, start=True)
# avoid sigint get handled by other processes
sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
for k in self.extra_threads_procs:
k.start()
signal.signal(signal.SIGINT, sigint_handler)
def process_grads(self, grads):
......
......@@ -8,6 +8,11 @@ import multiprocessing
import atexit
import bisect
import weakref
import six
if six.PY2:
import subprocess32 as subprocess
else:
import subprocess
__all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE']
......@@ -69,6 +74,18 @@ def ensure_proc_terminate(proc):
assert isinstance(proc, multiprocessing.Process)
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
def subproc_call(cmd, timeout=None):
try:
output = subprocess.check_output(
cmd, stderr=subprocess.STDOUT,
shell=True, timeout=timeout)
return output
except subprocess.TimeoutExpired as e:
logger.warn("Timeout in evaluation!")
logger.warn(e.output)
except subprocess.CalledProcessError as e:
logger.warn("Evaluation script failed: {}".format(e.returncode))
logger.warn(e.output)
class OrderedContainer(object):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment