Commit 322449d2 authored by Yuxin Wu

fix saver/thread order bug

parent 7237a1c8
......@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
import copy
from .base import DataFlow, ProxyDataFlow
from ..utils import *
......
......@@ -62,14 +62,19 @@ class Trainer(object):
self.summary_writer.add_summary(summary, self.global_step)
def main_loop(self):
# some final operations that might modify the graph
self._init_summary()
get_global_step_var()
callbacks = self.config.callbacks
callbacks.before_train(self)
self.config.session_init.init(self.sess)
tf.get_default_graph().finalize()
self._start_all_threads()
with self.sess.as_default():
try:
self.global_step = get_global_step()
logger.info("Start training with global_step={}".format(self.global_step))
tf.get_default_graph().finalize()
for epoch in xrange(1, self.config.max_epoch):
with timed_operation(
......@@ -97,10 +102,12 @@ class Trainer(object):
def init_session_and_coord(self):
describe_model()
self.sess = tf.Session(config=self.config.session_config)
self.config.session_init.init(self.sess)
# start training:
self.coord = tf.train.Coordinator()
def _start_all_threads(self):
"""
Run all threads before starting training
"""
tf.train.start_queue_runners(
sess=self.sess, coord=self.coord, daemon=True, start=True)
......@@ -142,12 +142,14 @@ class QueueInputTrainer(Trainer):
avg_maintain_op)
self.init_session_and_coord()
# create a thread that keeps filling the queue
input_th = EnqueueThread(self, input_queue, enqueue_op, input_vars)
input_th.start()
self.input_th = EnqueueThread(self, input_queue, enqueue_op, input_vars)
self.main_loop()
def _start_all_threads(self):
super(QueueInputTrainer, self)._start_all_threads()
self.input_th.start()
def run_step(self):
self.sess.run([self.train_op]) # faster since train_op return None
......
......@@ -43,7 +43,7 @@ class EnqueueThread(threading.Thread):
if self.coord.should_stop():
return
feed = dict(zip(self.input_vars, dp))
self.sess.run([self.op], feed_dict=feed)
self.op.run(feed_dict=feed, session=self.sess)
except tf.errors.CancelledError as e:
pass
except Exception:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment