Commit ed4e5106 authored by ppwwyyxx's avatar ppwwyyxx

fix some bugs

parent c713cb75
...@@ -16,7 +16,7 @@ class ImageFromFile(DataFlow): ...@@ -16,7 +16,7 @@ class ImageFromFile(DataFlow):
channel: 1 or 3 channel channel: 1 or 3 channel
resize: a (h, w) tuple. If given, will force a resize resize: a (h, w) tuple. If given, will force a resize
""" """
assert len(self.files) assert len(files)
self.files = files self.files = files
self.channel = int(channel) self.channel = int(channel)
self.resize = resize self.resize = resize
......
...@@ -39,7 +39,7 @@ def FixedUnPooling(x, shape, unpool_mat=None): ...@@ -39,7 +39,7 @@ def FixedUnPooling(x, shape, unpool_mat=None):
input_shape = x.get_shape().as_list() input_shape = x.get_shape().as_list()
assert len(input_shape) == 4 assert len(input_shape) == 4
if unpool_mat is None: if unpool_mat is None:
mat = np.zeros(shape) mat = np.zeros(shape, dtype='float32')
mat[0][0] = 1 mat[0][0] = 1
unpool_mat = tf.Variable(mat, trainable=False, name='unpool_mat') unpool_mat = tf.Variable(mat, trainable=False, name='unpool_mat')
assert unpool_mat.get_shape().as_list() == list(shape) assert unpool_mat.get_shape().as_list() == list(shape)
......
...@@ -103,6 +103,8 @@ def start_train(config): ...@@ -103,6 +103,8 @@ def start_train(config):
def get_model_inputs(): def get_model_inputs():
model_inputs = input_queue.dequeue() model_inputs = input_queue.dequeue()
if isinstance(model_inputs, tf.Tensor):
model_inputs = [model_inputs]
for qv, v in zip(model_inputs, input_vars): for qv, v in zip(model_inputs, input_vars):
if config.batched_model_input: if config.batched_model_input:
qv.set_shape(v.get_shape()) qv.set_shape(v.get_shape())
...@@ -179,7 +181,7 @@ def start_train(config): ...@@ -179,7 +181,7 @@ def start_train(config):
raise raise
finally: finally:
coord.request_stop() coord.request_stop()
queue.close(cancel_pending_enqueues=True) input_queue.close(cancel_pending_enqueues=True)
callbacks.after_train() callbacks.after_train()
sess.close() sess.close()
...@@ -55,8 +55,9 @@ def set_file(path): ...@@ -55,8 +55,9 @@ def set_file(path):
global LOG_FILE global LOG_FILE
LOG_FILE = "train_log/log.log" LOG_FILE = "train_log/log.log"
def set_logger_file(filename): def set_logger_file(filename):
global LOG_FILE global LOG_FILE, LOG_DIR
LOG_FILE = filename LOG_FILE = filename
LOG_DIR = os.path.dirname(LOG_FILE)
mkdir_p(os.path.dirname(LOG_FILE)) mkdir_p(os.path.dirname(LOG_FILE))
set_file(LOG_FILE) set_file(LOG_FILE)
...@@ -18,8 +18,7 @@ def one_hot(y, num_labels): ...@@ -18,8 +18,7 @@ def one_hot(y, num_labels):
return tf.cast(onehot_labels, tf.float32) return tf.cast(onehot_labels, tf.float32)
def flatten(x): def flatten(x):
total_dim = np.prod(x.get_shape().as_list()) return tf.reshape(x, [-1])
return tf.reshape(x, [total_dim])
def batch_flatten(x): def batch_flatten(x):
total_dim = np.prod(x.get_shape()[1:].as_list()) total_dim = np.prod(x.get_shape()[1:].as_list())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment