Commit 4185a222 authored by ppwwyyxx's avatar ppwwyyxx

fix some errors

parent 09e99778
......@@ -35,7 +35,7 @@ def Conv2D(x, out_channel, kernel_shape,
if b_init is None:
b_init = tf.constant_initializer()
W = tf.get_variable('W', filter_shape, initializer=W_init) # TODO collections
W = tf.get_variable('W', filter_shape, initializer=W_init)
b = tf.get_variable('b', [out_channel], initializer=b_init)
if split == 1:
......
......@@ -23,4 +23,4 @@ def FullyConnected(x, out_dim, W_init=None, b_init=None, nl=tf.nn.relu):
W = tf.get_variable('W', [in_dim, out_dim], initializer=W_init)
b = tf.get_variable('b', [out_dim], initializer=b_init)
return nl(tf.matmul(x, W) + b)
return nl(tf.matmul(x, W) + b, name=tf.get_variable_scope().name + '_output')
......@@ -87,7 +87,7 @@ def start_train(config):
# start training:
coord = tf.train.Coordinator()
# a thread that keeps filling the queue
input_th = EnqueueThread(sess, coord, enqueue_op, dataset)
input_th = EnqueueThread(sess, coord, enqueue_op, dataset, input_queue)
model_th = tf.train.start_queue_runners(
sess=sess, coord=coord, daemon=True, start=True)
input_th.start()
......@@ -101,7 +101,9 @@ def start_train(config):
if coord.should_stop():
return
fetches = [train_op, cost_var] + output_vars + model_inputs
print 'before'
results = sess.run(fetches)
print 'after'
cost = results[1]
outputs = results[2:2 + len(output_vars)]
inputs = results[-len(model_inputs):]
......
......@@ -25,13 +25,15 @@ class StoppableThread(threading.Thread):
class EnqueueThread(threading.Thread):
def __init__(self, sess, coord, enqueue_op, dataflow):
def __init__(self, sess, coord, enqueue_op, dataflow, queue):
super(EnqueueThread, self).__init__()
self.sess = sess
self.coord = coord
self.input_vars = sess.graph.get_collection(INPUT_VARS_KEY)
self.dataflow = dataflow
self.op = enqueue_op
self.queue = queue
self.daemon = True
def run(self):
......@@ -45,8 +47,8 @@ class EnqueueThread(threading.Thread):
except tf.errors.CancelledError as e:
pass
except Exception:
# TODO close queue.
logger.exception("Exception in EnqueueThread:")
self.queue.close(cancel_pending_enqueues=True)
self.coord.request_stop()
@contextmanager
......
......@@ -35,6 +35,7 @@ class ParamRestore(SessionInit):
self.prms = param_dict
def init(self, sess):
sess.run(tf.initialize_all_variables())
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var_dict = dict([v.name, v] for v in variables)
for name, value in self.prms.iteritems():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment