Commit ef4a15ca authored by Yuxin Wu's avatar Yuxin Wu

atexit for prefetch

parent 8d1ad775
......@@ -90,6 +90,7 @@ def get_config():
dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size()
step_per_epoch = 30
# prepare session
sess_config = get_default_sess_config()
......
......@@ -3,7 +3,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import itertools
from tqdm import tqdm
from abc import ABCMeta, abstractmethod
from six.moves import zip
......@@ -80,7 +79,7 @@ class ValidationStatPrinter(ValidationCallback):
stats = np.mean(stats, axis=0)
assert len(stats) == len(self.vars_to_print)
for stat, var in itertools.izip(stats, self.vars_to_print):
for stat, var in zip(stats, self.vars_to_print):
name = var.name.replace(':0', '')
self.trainer.summary_writer.add_summary(create_summary(
'{}_{}'.format(self.prefix, name), stat), self.global_step)
......
......@@ -2,9 +2,10 @@
# File: prefetch.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import multiprocessing
from .base import DataFlow
import multiprocessing
from ..utils.concurrency import ensure_procs_terminate
__all__ = ['PrefetchData']
......@@ -45,6 +46,7 @@ class PrefetchData(DataFlow):
def get_data(self):
queue = multiprocessing.Queue(self.nr_prefetch)
procs = [PrefetchProcess(self.ds, queue) for _ in range(self.nr_proc)]
ensure_procs_terminate(procs)
[x.start() for x in procs]
end_cnt = 0
......
......@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import threading
import copy
import re
from six.moves import zip
......@@ -10,25 +11,10 @@ from six.moves import zip
from .base import Trainer
from ..dataflow.common import RepeatedData
from ..utils import *
from ..utils.concurrency import EnqueueThread
from ..utils.summary import summary_moving_average
__all__ = ['SimpleTrainer', 'QueueInputTrainer', 'start_train']
def scale_grads(grads, multiplier):
    """Scale selected gradients by per-variable learning-rate multipliers.

    :param grads: list of ``(gradient, variable)`` pairs.
    :param multiplier: list of ``(regex, value)`` pairs; the first regex that
        matches a variable's name decides its multiplier.
    :returns: a new list of ``(gradient, variable)`` pairs, with matched
        gradients multiplied by their value and the rest unchanged.
    """
    scaled = []
    for grad, var in grads:
        factor = None
        for pattern, mult in multiplier:
            # first matching rule wins
            if re.search(pattern, var.name):
                logger.info("Apply lr multiplier {} for {}".format(mult, var.name))
                factor = mult
                break
        scaled.append((grad, var) if factor is None else (grad * factor, var))
    return scaled
class SimpleTrainer(Trainer):
def run_step(self):
data = next(self.data_producer)
......@@ -61,6 +47,34 @@ class SimpleTrainer(Trainer):
summary_str = self.summary_op.eval(feed_dict=feed)
self._process_summary(summary_str)
class EnqueueThread(threading.Thread):
    """Daemon thread that pumps datapoints from the trainer's DataFlow
    into a TensorFlow input queue via an enqueue op.

    Stops when the trainer's coordinator requests a stop; on exception or
    normal termination it closes the queue (cancelling pending enqueues)
    and asks the coordinator to stop.
    """

    def __init__(self, trainer, queue, enqueue_op, raw_input_var):
        super(EnqueueThread, self).__init__()
        # Cache everything needed from the trainer up front so run()
        # never has to touch the trainer object again.
        self.sess = trainer.sess
        self.coord = trainer.coord
        self.dataflow = trainer.config.dataset
        self.input_vars = raw_input_var
        self.op = enqueue_op
        self.queue = queue
        # Pre-built op used to shut the queue down on exit.
        self.close_op = self.queue.close(cancel_pending_enqueues=True)
        self.daemon = True

    def run(self):
        try:
            while True:
                for datapoint in self.dataflow.get_data():
                    if self.coord.should_stop():
                        # Stop was requested elsewhere: leave quietly
                        # without closing the queue ourselves.
                        return
                    self.op.run(
                        feed_dict=dict(zip(self.input_vars, datapoint)),
                        session=self.sess)
        except tf.errors.CancelledError:
            # Queue was closed under us — expected during shutdown.
            pass
        except Exception:
            logger.exception("Exception in EnqueueThread:")
        self.sess.run(self.close_op)
        self.coord.request_stop()
class QueueInputTrainer(Trainer):
"""
......
......@@ -5,10 +5,11 @@
import threading
from contextlib import contextmanager
import tensorflow as tf
import atexit
import weakref
from six.moves import zip
from .naming import *
from . import logger
class StoppableThread(threading.Thread):
def __init__(self):
......@@ -22,31 +23,19 @@ class StoppableThread(threading.Thread):
return self._stop.isSet()
class EnqueueThread(threading.Thread):
def __init__(self, trainer, queue, enqueue_op, raw_input_var):
super(EnqueueThread, self).__init__()
self.sess = trainer.sess
self.coord = trainer.coord
self.dataflow = trainer.config.dataset
self.input_vars = raw_input_var
self.op = enqueue_op
self.queue = queue
self.close_op = self.queue.close(cancel_pending_enqueues=True)
def ensure_proc_terminate(proc):
    """Register an atexit hook that terminates `proc` if it is still
    alive when the interpreter exits.

    :param proc: a ``multiprocessing.Process`` instance.

    A weak reference is registered instead of the process itself, so the
    atexit hook does not keep the process object alive once every other
    reference to it is gone.
    """
    def stop_proc_by_weak_ref(ref):
        proc = ref()
        if proc is None:
            # process object already garbage-collected — nothing to do
            return
        if not proc.is_alive():
            return
        proc.terminate()
        proc.join()

    assert isinstance(proc, multiprocessing.Process)
    atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
def run(self):
try:
while True:
for dp in self.dataflow.get_data():
if self.coord.should_stop():
return
feed = dict(zip(self.input_vars, dp))
self.op.run(feed_dict=feed, session=self.sess)
except tf.errors.CancelledError as e:
pass
except Exception:
logger.exception("Exception in EnqueueThread:")
self.sess.run(self.close_op)
self.coord.request_stop()
def ensure_procs_terminate(procs):
    """Apply :func:`ensure_proc_terminate` to every process in `procs`."""
    for proc in procs:
        ensure_proc_terminate(proc)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment