Commit 82b418fd authored by Yuxin Wu's avatar Yuxin Wu

small changes. add summary about queue size

parent 1898cd3c
...@@ -49,7 +49,7 @@ class SimulatorProcessBase(mp.Process): ...@@ -49,7 +49,7 @@ class SimulatorProcessBase(mp.Process):
def __init__(self, idx): def __init__(self, idx):
super(SimulatorProcessBase, self).__init__() super(SimulatorProcessBase, self).__init__()
self.idx = int(idx) self.idx = int(idx)
self.identity = u'simulator-{}'.format(self.idx).encode('utf-8') self.name = self.identity = u'simulator-{}'.format(self.idx).encode('utf-8')
@abstractmethod @abstractmethod
def _build_player(self): def _build_player(self):
...@@ -111,6 +111,7 @@ class SimulatorMaster(threading.Thread): ...@@ -111,6 +111,7 @@ class SimulatorMaster(threading.Thread):
def __init__(self, pipe_c2s, pipe_s2c): def __init__(self, pipe_c2s, pipe_s2c):
super(SimulatorMaster, self).__init__() super(SimulatorMaster, self).__init__()
self.daemon = True self.daemon = True
self.name = 'SimulatorMaster'
self.context = zmq.Context() self.context = zmq.Context()
......
...@@ -20,6 +20,7 @@ class StartProcOrThread(Callback): ...@@ -20,6 +20,7 @@ class StartProcOrThread(Callback):
self._procs_threads = procs_threads self._procs_threads = procs_threads
def _before_train(self): def _before_train(self):
logger.info("Starting all threads & procs ...") logger.info("Starting threads & procs: " + \
' .'.join([k.name for k in self._procs_threads]))
# avoid sigint get handled by other processes # avoid sigint get handled by other processes
start_proc_mask_signal(self._procs_threads) start_proc_mask_signal(self._procs_threads)
...@@ -57,6 +57,8 @@ class Callbacks(Callback): ...@@ -57,6 +57,8 @@ class Callbacks(Callback):
cbs.remove(sp) cbs.remove(sp)
cbs.append(sp) cbs.append(sp)
break break
else:
raise ValueError("Callbacks must contain StatPrinter for stat and writer to work properly!")
self.cbs = cbs self.cbs = cbs
......
...@@ -9,6 +9,7 @@ import json ...@@ -9,6 +9,7 @@ import json
from .base import Callback from .base import Callback
from ..utils import logger from ..utils import logger
from ..tfutils.common import get_global_step
__all__ = ['StatHolder', 'StatPrinter', 'SendStat'] __all__ = ['StatHolder', 'StatPrinter', 'SendStat']
...@@ -102,11 +103,15 @@ class StatPrinter(Callback): ...@@ -102,11 +103,15 @@ class StatPrinter(Callback):
self.print_tag = print_tag self.print_tag = print_tag
def _before_train(self): def _before_train(self):
self.trainer.stat_holder.set_print_tag(self.print_tag) self._stat_holder = self.trainer.stat_holder
self.trainer.stat_holder.add_blacklist_tag(['global_step', 'epoch_num']) self._stat_holder.set_print_tag(self.print_tag)
self._stat_holder.add_blacklist_tag(['global_step', 'epoch_num'])
def _trigger_epoch(self): def _trigger_epoch(self):
self.trainer.stat_holder.finalize() # by default, add this two stat
self._stat_holder.add_stat('global_step', get_global_step())
self._stat_holder.add_stat('epoch_num', self.epoch_num)
self._stat_holder.finalize()
class SendStat(Callback): class SendStat(Callback):
""" """
......
...@@ -42,7 +42,8 @@ class Trainer(object): ...@@ -42,7 +42,8 @@ class Trainer(object):
self.config = config self.config = config
self.model = config.model self.model = config.model
self.model.get_input_vars() # ensure they are present self.model.get_input_vars() # ensure they are present
self.init_session_and_coord() self.sess = tf.Session(config=self.config.session_config)
self.coord = tf.train.Coordinator()
@abstractmethod @abstractmethod
def train(self): def train(self):
...@@ -67,10 +68,6 @@ class Trainer(object): ...@@ -67,10 +68,6 @@ class Trainer(object):
return [self.get_predict_func(input_names, output_names) for k in range(n)] return [self.get_predict_func(input_names, output_names) for k in range(n)]
def trigger_epoch(self): def trigger_epoch(self):
# by default, add this two stat
self.stat_holder.add_stat('global_step', get_global_step())
self.stat_holder.add_stat('epoch_num', self.epoch_num)
# trigger subclass # trigger subclass
self._trigger_epoch() self._trigger_epoch()
# trigger callbacks # trigger callbacks
...@@ -92,11 +89,10 @@ class Trainer(object): ...@@ -92,11 +89,10 @@ class Trainer(object):
def write_scalar_summary(self, name, val): def write_scalar_summary(self, name, val):
self.summary_writer.add_summary( self.summary_writer.add_summary(
create_summary(name, val), create_summary(name, val), get_global_step())
get_global_step())
self.stat_holder.add_stat(name, val) self.stat_holder.add_stat(name, val)
def finalize_graph(self): def finalize(self):
# some final operations that might modify the graph # some final operations that might modify the graph
logger.info("Setup callbacks ...") logger.info("Setup callbacks ...")
self.config.callbacks.setup_graph(weakref.proxy(self)) self.config.callbacks.setup_graph(weakref.proxy(self))
...@@ -111,22 +107,23 @@ class Trainer(object): ...@@ -111,22 +107,23 @@ class Trainer(object):
logger.info("Initializing graph variables ...") logger.info("Initializing graph variables ...")
self.sess.run(tf.initialize_all_variables()) self.sess.run(tf.initialize_all_variables())
self.config.session_init.init(self.sess) self.config.session_init.init(self.sess)
tf.get_default_graph().finalize() tf.get_default_graph().finalize()
tf.train.start_queue_runners( tf.train.start_queue_runners(
sess=self.sess, coord=self.coord, daemon=True, start=True) sess=self.sess, coord=self.coord, daemon=True, start=True)
def main_loop(self): def main_loop(self):
self.finalize_graph() self.finalize()
callbacks = self.config.callbacks callbacks = self.config.callbacks
with self.sess.as_default(): with self.sess.as_default():
try: try:
callbacks.before_train() callbacks.before_train()
logger.info("Start training with global_step={}".format(get_global_step())) logger.info("Start training with global_step={}".format(get_global_step()))
for self.epoch_num in range( for epoch_num in range(
self.config.starting_epoch, self.config.max_epoch+1): self.config.starting_epoch, self.config.max_epoch+1):
with timed_operation( with timed_operation(
'Epoch {} (global_step {})'.format( 'Epoch {} (global_step {})'.format(
self.epoch_num, get_global_step() + self.config.step_per_epoch)): epoch_num, get_global_step() + self.config.step_per_epoch)):
for step in tqdm.trange( for step in tqdm.trange(
self.config.step_per_epoch, self.config.step_per_epoch,
**get_tqdm_kwargs(leave=True)): **get_tqdm_kwargs(leave=True)):
...@@ -137,18 +134,12 @@ class Trainer(object): ...@@ -137,18 +134,12 @@ class Trainer(object):
self.trigger_epoch() self.trigger_epoch()
except StopTraining: except StopTraining:
logger.info("Training was stopped.") logger.info("Training was stopped.")
except (KeyboardInterrupt, Exception):
raise
finally: finally:
callbacks.after_train() callbacks.after_train()
self.coord.request_stop() self.coord.request_stop()
self.summary_writer.close() self.summary_writer.close()
self.sess.close() self.sess.close()
def init_session_and_coord(self):
self.sess = tf.Session(config=self.config.session_config)
self.coord = tf.train.Coordinator()
def process_grads(self, grads): def process_grads(self, grads):
g = [] g = []
for grad, var in grads: for grad, var in grads:
......
...@@ -98,6 +98,7 @@ class SimpleTrainer(Trainer): ...@@ -98,6 +98,7 @@ class SimpleTrainer(Trainer):
class EnqueueThread(threading.Thread): class EnqueueThread(threading.Thread):
def __init__(self, trainer): def __init__(self, trainer):
super(EnqueueThread, self).__init__() super(EnqueueThread, self).__init__()
self.name = 'EnqueueThread'
self.sess = trainer.sess self.sess = trainer.sess
self.coord = trainer.coord self.coord = trainer.coord
self.dataflow = RepeatedData(trainer.config.dataset, -1) self.dataflow = RepeatedData(trainer.config.dataset, -1)
...@@ -117,9 +118,6 @@ class EnqueueThread(threading.Thread): ...@@ -117,9 +118,6 @@ class EnqueueThread(threading.Thread):
try: try:
while True: while True:
for dp in self.dataflow.get_data(): for dp in self.dataflow.get_data():
#import IPython;
#IPython.embed(config=IPython.terminal.ipapp.load_default_config())
if self.coord.should_stop(): if self.coord.should_stop():
return return
feed = dict(zip(self.input_vars, dp)) feed = dict(zip(self.input_vars, dp))
...@@ -150,7 +148,6 @@ class QueueInputTrainer(Trainer): ...@@ -150,7 +148,6 @@ class QueueInputTrainer(Trainer):
""" """
super(QueueInputTrainer, self).__init__(config) super(QueueInputTrainer, self).__init__(config)
self.input_vars = self.model.get_input_vars() self.input_vars = self.model.get_input_vars()
# use a smaller queue size for now, to avoid https://github.com/tensorflow/tensorflow/issues/2942
if input_queue is None: if input_queue is None:
self.input_queue = tf.FIFOQueue( self.input_queue = tf.FIFOQueue(
50, [x.dtype for x in self.input_vars], name='input_queue') 50, [x.dtype for x in self.input_vars], name='input_queue')
...@@ -174,6 +171,8 @@ class QueueInputTrainer(Trainer): ...@@ -174,6 +171,8 @@ class QueueInputTrainer(Trainer):
def _single_tower_grad(self): def _single_tower_grad(self):
""" Get grad and cost for single-tower""" """ Get grad and cost for single-tower"""
self.dequed_inputs = model_inputs = self._get_dequeued_inputs() self.dequed_inputs = model_inputs = self._get_dequeued_inputs()
add_moving_summary(tf.cast(
self.input_queue.size(), tf.float32, name='input-queue-size'))
# test the overhead of queue # test the overhead of queue
#with tf.device('/gpu:0'): #with tf.device('/gpu:0'):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment