Commit 3629d9ca authored by ppwwyyxx

use daemon thread

parent e8601aec
@@ -18,6 +18,21 @@ def prepare():
     global_step_var = tf.Variable(
         0, trainable=False, name=GLOBAL_STEP_OP_NAME)
 
+def get_train_op(optimizer, cost_var):
+    global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
+    avg_maintain_op = summary_moving_average(cost_var)
+    # maintain average in each step
+    with tf.control_dependencies([avg_maintain_op]):
+        grads = optimizer.compute_gradients(cost_var)
+        for grad, var in grads:
+            if grad:
+                tf.histogram_summary(var.op.name + '/gradients', grad)
+        return optimizer.apply_gradients(grads, global_step_var)
+
 def start_train(config):
     """
     Start training with the given config
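The new get_train_op helper hinges on tf.control_dependencies: ops created inside the block (here, the apply_gradients op) only run after avg_maintain_op has run, so the cost moving average is refreshed on every training step. Below is a minimal, self-contained sketch of that pattern, not the repo's code; it uses hypothetical variable names and the same TF 0.x-era API as this commit.

import tensorflow as tf

counter = tf.Variable(0, trainable=False, name='counter')
bump_counter = tf.assign_add(counter, 1)   # stands in for avg_maintain_op

x = tf.Variable(3.0)
cost = tf.square(x)
opt = tf.train.GradientDescentOptimizer(0.1)

# Ops created inside this block depend on bump_counter, so running train_op
# also bumps the counter -- the same trick get_train_op uses to keep the
# moving average up to date on every step.
with tf.control_dependencies([bump_counter]):
    train_op = opt.apply_gradients(opt.compute_gradients(cost))

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    sess.run(train_op)
    print(sess.run(counter))   # prints 1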
@@ -58,27 +73,14 @@ def start_train(config):
     output_vars, cost_var = get_model_func(model_inputs, is_training=True)
     # build graph
-    G = tf.get_default_graph()
-    G.add_to_collection(FORWARD_FUNC_KEY, get_model_func)
+    tf.add_to_collection(FORWARD_FUNC_KEY, get_model_func)
     for v in input_vars:
-        G.add_to_collection(INPUT_VARS_KEY, v)
+        tf.add_to_collection(INPUT_VARS_KEY, v)
     for v in output_vars:
-        G.add_to_collection(OUTPUT_VARS_KEY, v)
+        tf.add_to_collection(OUTPUT_VARS_KEY, v)
     describe_model()
-    global_step_var = G.get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
-    avg_maintain_op = summary_moving_average(cost_var)
-    # maintain average in each step
-    with tf.control_dependencies([avg_maintain_op]):
-        grads = optimizer.compute_gradients(cost_var)
-        for grad, var in grads:
-            if grad:
-                tf.histogram_summary(var.op.name + '/gradients', grad)
-        train_op = optimizer.apply_gradients(grads, global_step_var)
+    train_op = get_train_op(optimizer, cost_var)
 
     sess = tf.Session(config=sess_config)
     sess.run(tf.initialize_all_variables())
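The switch from G.add_to_collection to the module-level tf.add_to_collection is behaviorally equivalent: both write into the default graph's collections, which EnqueueThread later reads back through sess.graph.get_collection (see the second file below). A hedged sketch of that round-trip, with a made-up key standing in for the repo's INPUT_VARS_KEY constant:

import tensorflow as tf

INPUT_VARS_KEY = 'input_vars'   # hypothetical key; the repo defines its own constant

x = tf.placeholder(tf.float32, [None, 28, 28], name='input')
# same effect as tf.get_default_graph().add_to_collection(INPUT_VARS_KEY, x)
tf.add_to_collection(INPUT_VARS_KEY, x)

with tf.Session() as sess:
    # EnqueueThread recovers the input placeholders the same way
    inputs = sess.graph.get_collection(INPUT_VARS_KEY)
    assert inputs == [x]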
@@ -90,11 +92,11 @@ def start_train(config):
     # a thread that keeps filling the queue
     input_th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
     model_th = tf.train.start_queue_runners(
-        sess=sess, coord=coord, daemon=True, start=False)
+        sess=sess, coord=coord, daemon=True, start=True)
+    input_th.start()
 
     with sess.as_default(), \
-            coordinator_guard(
-                sess, coord, [input_th] + model_th, input_queue):
+            coordinator_guard(sess, coord):
         callbacks.before_train()
         for epoch in xrange(1, max_epoch):
             with timed_operation('epoch {}'.format(epoch)):
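With start=True, the graph's registered queue runners are launched as daemon threads the moment start_queue_runners is called, and the custom EnqueueThread is started explicitly right after, so coordinator_guard no longer has to start or join anything. A rough, self-contained sketch of that startup pattern (not the repo's code), using a throwaway FIFOQueue in place of the training input queue:

import tensorflow as tf

queue = tf.FIFOQueue(capacity=10, dtypes=[tf.float32])
enqueue_op = queue.enqueue(tf.random_uniform([]))
tf.train.add_queue_runner(tf.train.QueueRunner(queue, [enqueue_op]))

sess = tf.Session()
coord = tf.train.Coordinator()
# daemon=True: these threads cannot keep the process alive at exit;
# start=True: they begin filling the queue immediately.
runner_threads = tf.train.start_queue_runners(
    sess=sess, coord=coord, daemon=True, start=True)

print(sess.run(queue.dequeue()))   # an element produced in the background
coord.request_stop()
sess.close()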
@@ -32,6 +32,7 @@ class EnqueueThread(threading.Thread):
         self.input_vars = sess.graph.get_collection(INPUT_VARS_KEY)
         self.dataflow = dataflow
         self.op = enqueue_op
+        self.daemon = True
 
     def run(self):
         try:
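This one-line change is what the commit title refers to: marking EnqueueThread as a daemon means a blocked enqueue can no longer keep the interpreter alive after the main thread finishes. A plain-Python illustration of the difference, with no TensorFlow involved:

import threading
import time

def busy_loop():
    while True:
        time.sleep(1)   # stands in for a blocking queue enqueue

th = threading.Thread(target=busy_loop)
th.daemon = True        # same effect as `self.daemon = True` in __init__
th.start()
print("main thread exits; the daemon thread dies with the process")
# With daemon=False the interpreter would hang here, waiting on busy_loop().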
@@ -49,20 +50,11 @@ class EnqueueThread(threading.Thread):
             self.coord.request_stop()
 
 @contextmanager
-def coordinator_guard(sess, coord, threads, queue):
-    """
-    Context manager to make sure that:
-        queue is closed
-        threads are joined
-    """
-    for th in threads:
-        th.start()
+def coordinator_guard(sess, coord):
     try:
         yield
     except (KeyboardInterrupt, Exception) as e:
         raise
     finally:
         coord.request_stop()
-        sess.run(
-            queue.close(cancel_pending_enqueues=True))
-        coord.join(threads)
+        sess.close()
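Because every worker thread is now a daemon and is started before entering the guard, the shutdown path shrinks to "stop the coordinator, close the session": there is no queue to close and nothing to join. A hedged sketch of how the slimmed-down guard is meant to be used, with a minimal stand-in setup instead of a real training loop:

from contextlib import contextmanager
import tensorflow as tf

@contextmanager
def coordinator_guard(sess, coord):
    # mirrors the new version in the diff, minus the re-raise clause
    try:
        yield
    finally:
        coord.request_stop()   # daemon threads notice this and wind down
        sess.close()           # no queue.close() / coord.join() needed

sess = tf.Session()
coord = tf.train.Coordinator()
with sess.as_default(), coordinator_guard(sess, coord):
    pass   # the epoch loop from start_train() would run here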