Commit ed4e5106 authored by ppwwyyxx's avatar ppwwyyxx

fix some bugs

parent c713cb75
......@@ -16,7 +16,7 @@ class ImageFromFile(DataFlow):
channel: 1 or 3 channel
resize: a (h, w) tuple. If given, will force a resize
"""
assert len(self.files)
assert len(files)
self.files = files
self.channel = int(channel)
self.resize = resize
......
......@@ -39,7 +39,7 @@ def FixedUnPooling(x, shape, unpool_mat=None):
input_shape = x.get_shape().as_list()
assert len(input_shape) == 4
if unpool_mat is None:
mat = np.zeros(shape)
mat = np.zeros(shape, dtype='float32')
mat[0][0] = 1
unpool_mat = tf.Variable(mat, trainable=False, name='unpool_mat')
assert unpool_mat.get_shape().as_list() == list(shape)
......
......@@ -103,6 +103,8 @@ def start_train(config):
def get_model_inputs():
model_inputs = input_queue.dequeue()
if isinstance(model_inputs, tf.Tensor):
model_inputs = [model_inputs]
for qv, v in zip(model_inputs, input_vars):
if config.batched_model_input:
qv.set_shape(v.get_shape())
......@@ -179,7 +181,7 @@ def start_train(config):
raise
finally:
coord.request_stop()
queue.close(cancel_pending_enqueues=True)
input_queue.close(cancel_pending_enqueues=True)
callbacks.after_train()
sess.close()
......@@ -55,8 +55,9 @@ def set_file(path):
global LOG_FILE
LOG_FILE = "train_log/log.log"
def set_logger_file(filename):
global LOG_FILE
global LOG_FILE, LOG_DIR
LOG_FILE = filename
LOG_DIR = os.path.dirname(LOG_FILE)
mkdir_p(os.path.dirname(LOG_FILE))
set_file(LOG_FILE)
......@@ -18,8 +18,7 @@ def one_hot(y, num_labels):
return tf.cast(onehot_labels, tf.float32)
def flatten(x):
total_dim = np.prod(x.get_shape().as_list())
return tf.reshape(x, [total_dim])
return tf.reshape(x, [-1])
def batch_flatten(x):
total_dim = np.prod(x.get_shape()[1:].as_list())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment