Commit 6607d856 authored by Yuxin Wu's avatar Yuxin Wu

fix unpool unknown shape problem

parent e51855c5
......@@ -32,6 +32,7 @@ def Conv2D(x, out_channel, kernel_shape,
"""
in_shape = x.get_shape().as_list()
in_channel = in_shape[-1]
assert in_channel is not None, "Input to Conv2D cannot have unknown channel!"
assert in_channel % split == 0
assert out_channel % split == 0
......
......@@ -74,7 +74,9 @@ def UnPooling2x2ZeroFilled(x):
return tf.reshape(out, out_size)
else:
sh = tf.shape(x)
return tf.reshape(out, [-1, sh[1] * 2, sh[2] * 2, sh[3]])
ret = tf.reshape(out, tf.pack([-1, sh[1] * 2, sh[2] * 2, sh[3]]))
ret.set_shape([None, None, None, sh[3]])
return ret
@layer_register()
def FixedUnPooling(x, shape, unpool_mat=None):
......
......@@ -90,5 +90,6 @@ def get_predict_func(config):
def run_input(dp):
feed = dict(zip(input_map, dp))
return sess.run(output_vars, feed_dict=feed)
# XXX hack. so the caller can get access to the session.
run_input.session = sess
return run_input
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment