Commit 4fa2837e authored by Yuxin Wu's avatar Yuxin Wu

add graph.finalize

parent dcf55733
...@@ -171,6 +171,8 @@ def start_train(config): ...@@ -171,6 +171,8 @@ def start_train(config):
try: try:
logger.info("Start training with global_step={}".format(get_global_step())) logger.info("Start training with global_step={}".format(get_global_step()))
callbacks.before_train() callbacks.before_train()
tf.get_default_graph().finalize()
for epoch in xrange(1, config.max_epoch): for epoch in xrange(1, config.max_epoch):
with timed_operation('epoch {}'.format(epoch)): with timed_operation('epoch {}'.format(epoch)):
for step in tqdm.trange( for step in tqdm.trange(
...@@ -186,7 +188,7 @@ def start_train(config): ...@@ -186,7 +188,7 @@ def start_train(config):
raise raise
finally: finally:
coord.request_stop() coord.request_stop()
input_queue.close(cancel_pending_enqueues=True) # Do I need to run queue.close
callbacks.after_train() callbacks.after_train()
sess.close() sess.close()
...@@ -33,6 +33,7 @@ class EnqueueThread(threading.Thread): ...@@ -33,6 +33,7 @@ class EnqueueThread(threading.Thread):
self.dataflow = dataflow self.dataflow = dataflow
self.op = enqueue_op self.op = enqueue_op
self.queue = queue self.queue = queue
self.close_op = self.queue.close(cancel_pending_enqueues=True)
self.daemon = True self.daemon = True
...@@ -49,5 +50,5 @@ class EnqueueThread(threading.Thread): ...@@ -49,5 +50,5 @@ class EnqueueThread(threading.Thread):
pass pass
except Exception: except Exception:
logger.exception("Exception in EnqueueThread:") logger.exception("Exception in EnqueueThread:")
self.queue.close(cancel_pending_enqueues=True) self.sess.run(self.close_op)
self.coord.request_stop() self.coord.request_stop()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment