Commit 322449d2 authored by Yuxin Wu's avatar Yuxin Wu

fix saver/thread order bug

parent 7237a1c8
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np import numpy as np
import copy
from .base import DataFlow, ProxyDataFlow from .base import DataFlow, ProxyDataFlow
from ..utils import * from ..utils import *
......
...@@ -62,14 +62,19 @@ class Trainer(object): ...@@ -62,14 +62,19 @@ class Trainer(object):
self.summary_writer.add_summary(summary, self.global_step) self.summary_writer.add_summary(summary, self.global_step)
def main_loop(self): def main_loop(self):
# some final operations that might modify the graph
self._init_summary() self._init_summary()
get_global_step_var()
callbacks = self.config.callbacks callbacks = self.config.callbacks
callbacks.before_train(self) callbacks.before_train(self)
self.config.session_init.init(self.sess)
tf.get_default_graph().finalize()
self._start_all_threads()
with self.sess.as_default(): with self.sess.as_default():
try: try:
self.global_step = get_global_step() self.global_step = get_global_step()
logger.info("Start training with global_step={}".format(self.global_step)) logger.info("Start training with global_step={}".format(self.global_step))
tf.get_default_graph().finalize()
for epoch in xrange(1, self.config.max_epoch): for epoch in xrange(1, self.config.max_epoch):
with timed_operation( with timed_operation(
...@@ -97,10 +102,12 @@ class Trainer(object): ...@@ -97,10 +102,12 @@ class Trainer(object):
def init_session_and_coord(self): def init_session_and_coord(self):
describe_model() describe_model()
self.sess = tf.Session(config=self.config.session_config) self.sess = tf.Session(config=self.config.session_config)
self.config.session_init.init(self.sess)
# start training:
self.coord = tf.train.Coordinator() self.coord = tf.train.Coordinator()
def _start_all_threads(self):
"""
Run all threads before starting training
"""
tf.train.start_queue_runners( tf.train.start_queue_runners(
sess=self.sess, coord=self.coord, daemon=True, start=True) sess=self.sess, coord=self.coord, daemon=True, start=True)
...@@ -142,12 +142,14 @@ class QueueInputTrainer(Trainer): ...@@ -142,12 +142,14 @@ class QueueInputTrainer(Trainer):
avg_maintain_op) avg_maintain_op)
self.init_session_and_coord() self.init_session_and_coord()
# create a thread that keeps filling the queue # create a thread that keeps filling the queue
input_th = EnqueueThread(self, input_queue, enqueue_op, input_vars) self.input_th = EnqueueThread(self, input_queue, enqueue_op, input_vars)
input_th.start()
self.main_loop() self.main_loop()
def _start_all_threads(self):
super(QueueInputTrainer, self)._start_all_threads()
self.input_th.start()
def run_step(self): def run_step(self):
self.sess.run([self.train_op]) # faster since train_op return None self.sess.run([self.train_op]) # faster since train_op return None
......
...@@ -43,7 +43,7 @@ class EnqueueThread(threading.Thread): ...@@ -43,7 +43,7 @@ class EnqueueThread(threading.Thread):
if self.coord.should_stop(): if self.coord.should_stop():
return return
feed = dict(zip(self.input_vars, dp)) feed = dict(zip(self.input_vars, dp))
self.sess.run([self.op], feed_dict=feed) self.op.run(feed_dict=feed, session=self.sess)
except tf.errors.CancelledError as e: except tf.errors.CancelledError as e:
pass pass
except Exception: except Exception:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment