Commit 6607d856 authored by Yuxin Wu's avatar Yuxin Wu

fix unpool unknown shape problem

parent e51855c5
...@@ -32,6 +32,7 @@ def Conv2D(x, out_channel, kernel_shape, ...@@ -32,6 +32,7 @@ def Conv2D(x, out_channel, kernel_shape,
""" """
in_shape = x.get_shape().as_list() in_shape = x.get_shape().as_list()
in_channel = in_shape[-1] in_channel = in_shape[-1]
assert in_channel is not None, "Input to Conv2D cannot have unknown channel!"
assert in_channel % split == 0 assert in_channel % split == 0
assert out_channel % split == 0 assert out_channel % split == 0
......
...@@ -74,7 +74,9 @@ def UnPooling2x2ZeroFilled(x): ...@@ -74,7 +74,9 @@ def UnPooling2x2ZeroFilled(x):
return tf.reshape(out, out_size) return tf.reshape(out, out_size)
else: else:
sh = tf.shape(x) sh = tf.shape(x)
return tf.reshape(out, [-1, sh[1] * 2, sh[2] * 2, sh[3]]) ret = tf.reshape(out, tf.pack([-1, sh[1] * 2, sh[2] * 2, sh[3]]))
ret.set_shape([None, None, None, sh[3]])
return ret
@layer_register() @layer_register()
def FixedUnPooling(x, shape, unpool_mat=None): def FixedUnPooling(x, shape, unpool_mat=None):
......
...@@ -90,5 +90,6 @@ def get_predict_func(config): ...@@ -90,5 +90,6 @@ def get_predict_func(config):
def run_input(dp): def run_input(dp):
feed = dict(zip(input_map, dp)) feed = dict(zip(input_map, dp))
return sess.run(output_vars, feed_dict=feed) return sess.run(output_vars, feed_dict=feed)
# XXX hack. so the caller can get access to the session.
run_input.session = sess run_input.session = sess
return run_input return run_input
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment