Commit ef4a15ca authored by Yuxin Wu's avatar Yuxin Wu

atexit for prefetch

parent 8d1ad775
...@@ -90,6 +90,7 @@ def get_config(): ...@@ -90,6 +90,7 @@ def get_config():
dataset_train = BatchData(dataset.Mnist('train'), 128) dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True) dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size() step_per_epoch = dataset_train.size()
step_per_epoch = 30
# prepare session # prepare session
sess_config = get_default_sess_config() sess_config = get_default_sess_config()
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
import itertools
from tqdm import tqdm from tqdm import tqdm
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from six.moves import zip from six.moves import zip
...@@ -80,7 +79,7 @@ class ValidationStatPrinter(ValidationCallback): ...@@ -80,7 +79,7 @@ class ValidationStatPrinter(ValidationCallback):
stats = np.mean(stats, axis=0) stats = np.mean(stats, axis=0)
assert len(stats) == len(self.vars_to_print) assert len(stats) == len(self.vars_to_print)
for stat, var in itertools.izip(stats, self.vars_to_print): for stat, var in zip(stats, self.vars_to_print):
name = var.name.replace(':0', '') name = var.name.replace(':0', '')
self.trainer.summary_writer.add_summary(create_summary( self.trainer.summary_writer.add_summary(create_summary(
'{}_{}'.format(self.prefix, name), stat), self.global_step) '{}_{}'.format(self.prefix, name), stat), self.global_step)
......
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
# File: prefetch.py # File: prefetch.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import multiprocessing
from .base import DataFlow from .base import DataFlow
import multiprocessing from ..utils.concurrency import ensure_procs_terminate
__all__ = ['PrefetchData'] __all__ = ['PrefetchData']
...@@ -45,6 +46,7 @@ class PrefetchData(DataFlow): ...@@ -45,6 +46,7 @@ class PrefetchData(DataFlow):
def get_data(self): def get_data(self):
queue = multiprocessing.Queue(self.nr_prefetch) queue = multiprocessing.Queue(self.nr_prefetch)
procs = [PrefetchProcess(self.ds, queue) for _ in range(self.nr_proc)] procs = [PrefetchProcess(self.ds, queue) for _ in range(self.nr_proc)]
ensure_procs_terminate(procs)
[x.start() for x in procs] [x.start() for x in procs]
end_cnt = 0 end_cnt = 0
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
import threading
import copy import copy
import re import re
from six.moves import zip from six.moves import zip
...@@ -10,25 +11,10 @@ from six.moves import zip ...@@ -10,25 +11,10 @@ from six.moves import zip
from .base import Trainer from .base import Trainer
from ..dataflow.common import RepeatedData from ..dataflow.common import RepeatedData
from ..utils import * from ..utils import *
from ..utils.concurrency import EnqueueThread
from ..utils.summary import summary_moving_average from ..utils.summary import summary_moving_average
__all__ = ['SimpleTrainer', 'QueueInputTrainer', 'start_train'] __all__ = ['SimpleTrainer', 'QueueInputTrainer', 'start_train']
def scale_grads(grads, multiplier):
    """Apply per-variable learning-rate multipliers to a list of gradients.

    Each variable name is tested against the given regex patterns; the first
    pattern that matches scales that variable's gradient by the paired value.
    Variables with no matching pattern keep their gradient unchanged.

    Args:
        grads: list of ``(gradient, variable)`` pairs.
        multiplier: list of ``(regex, value)`` pairs; ``regex`` is searched
            (not full-matched) against ``variable.name``.

    Returns:
        A new list of ``(gradient, variable)`` pairs in the same order.
    """
    scaled = []
    for grad, var in grads:
        varname = var.name
        matched = False
        for regex, val in multiplier:
            if re.search(regex, varname):
                logger.info("Apply lr multiplier {} for {}".format(val, varname))
                scaled.append((grad * val, var))
                matched = True
                break  # only the first matching pattern applies
        if not matched:
            scaled.append((grad, var))
    return scaled
class SimpleTrainer(Trainer): class SimpleTrainer(Trainer):
def run_step(self): def run_step(self):
data = next(self.data_producer) data = next(self.data_producer)
...@@ -61,6 +47,34 @@ class SimpleTrainer(Trainer): ...@@ -61,6 +47,34 @@ class SimpleTrainer(Trainer):
summary_str = self.summary_op.eval(feed_dict=feed) summary_str = self.summary_op.eval(feed_dict=feed)
self._process_summary(summary_str) self._process_summary(summary_str)
class EnqueueThread(threading.Thread):
    """Background thread that pumps datapoints from a DataFlow into a TF queue.

    Runs the given ``enqueue_op`` repeatedly, feeding each datapoint from the
    trainer's dataset into ``raw_input_var``, until the trainer's coordinator
    requests a stop or the queue is closed.
    """
    def __init__(self, trainer, queue, enqueue_op, raw_input_var):
        super(EnqueueThread, self).__init__()
        # Borrowed from the trainer; this thread does not own them.
        self.sess = trainer.sess
        self.coord = trainer.coord
        self.dataflow = trainer.config.dataset

        # Placeholders to feed, and the op that pushes one datapoint.
        self.input_vars = raw_input_var
        self.op = enqueue_op
        self.queue = queue
        # Pre-built close op; cancels enqueues blocked on a full queue so
        # run() can unwind instead of hanging at shutdown.
        self.close_op = self.queue.close(cancel_pending_enqueues=True)

        # Daemon so the process can exit even if this thread is still running.
        self.daemon = True

    def run(self):
        try:
            # Loop over the dataflow forever; get_data() is restarted each
            # time it is exhausted (the outer while provides the repetition).
            while True:
                for dp in self.dataflow.get_data():
                    if self.coord.should_stop():
                        # NOTE(review): returning here skips the close_op /
                        # request_stop lines below — presumably fine because
                        # should_stop means shutdown is already in progress.
                        return
                    feed = dict(zip(self.input_vars, dp))
                    self.op.run(feed_dict=feed, session=self.sess)
        except tf.errors.CancelledError as e:
            # Queue was closed elsewhere; normal shutdown path, not an error.
            pass
        except Exception:
            logger.exception("Exception in EnqueueThread:")
        # Reached after CancelledError or any other exception: close the
        # queue (cancelling pending enqueues) and tell the coordinator to
        # stop the other threads.
        self.sess.run(self.close_op)
        self.coord.request_stop()
class QueueInputTrainer(Trainer): class QueueInputTrainer(Trainer):
""" """
......
...@@ -5,10 +5,11 @@ ...@@ -5,10 +5,11 @@
import threading import threading
from contextlib import contextmanager from contextlib import contextmanager
import tensorflow as tf import tensorflow as tf
import atexit
import weakref
from six.moves import zip from six.moves import zip
from .naming import * from .naming import *
from . import logger
class StoppableThread(threading.Thread): class StoppableThread(threading.Thread):
def __init__(self): def __init__(self):
...@@ -22,31 +23,19 @@ class StoppableThread(threading.Thread): ...@@ -22,31 +23,19 @@ class StoppableThread(threading.Thread):
return self._stop.isSet() return self._stop.isSet()
class EnqueueThread(threading.Thread): def ensure_proc_terminate(proc):
def __init__(self, trainer, queue, enqueue_op, raw_input_var): def stop_proc_by_weak_ref(ref):
super(EnqueueThread, self).__init__() proc = ref()
self.sess = trainer.sess if proc is None:
self.coord = trainer.coord return
self.dataflow = trainer.config.dataset if not proc.is_alive():
return
self.input_vars = raw_input_var proc.terminate()
self.op = enqueue_op proc.join()
self.queue = queue
self.close_op = self.queue.close(cancel_pending_enqueues=True) assert isinstance(proc, multiprocessing.Process)
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
self.daemon = True
def ensure_procs_terminate(procs):
def run(self): for p in procs:
try: ensure_proc_terminate(p)
while True:
for dp in self.dataflow.get_data():
if self.coord.should_stop():
return
feed = dict(zip(self.input_vars, dp))
self.op.run(feed_dict=feed, session=self.sess)
except tf.errors.CancelledError as e:
pass
except Exception:
logger.exception("Exception in EnqueueThread:")
self.sess.run(self.close_op)
self.coord.request_stop()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment