Commit ad5a720a authored by Yuxin Wu's avatar Yuxin Wu

fix shape bug in bilinear

parent 0deff876
...@@ -43,4 +43,4 @@ To visualize the agent: ...@@ -43,4 +43,4 @@ To visualize the agent:
./DQN.py --rom breakout.bin --task play --load pretrained.model ./DQN.py --rom breakout.bin --task play --load pretrained.model
``` ```
A3C code will be released very soon. A3C code and models for Atari games in OpenAI Gym are released in [examples/OpenAIGym](../OpenAIGym)
# tensorpack examples # tensorpack examples
Examples with __reproducible__ and meaningful performancce. Examples with __reproducible__ and meaningful performance.
+ [An illustrative mnist example](mnist-convnet.py) + [An illustrative mnist example](mnist-convnet.py)
+ [A tiny SVHN ConvNet with 97.5% accuracy](svhn-digit-convnet.py) + [A tiny SVHN ConvNet with 97.5% accuracy](svhn-digit-convnet.py)
......
...@@ -136,7 +136,9 @@ def BilinearUpSample(x, shape): ...@@ -136,7 +136,9 @@ def BilinearUpSample(x, shape):
ret[x,y] = (1 - abs(x / f - c)) * (1 - abs(y / f - c)) ret[x,y] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
return ret return ret
ch = x.get_shape().as_list()[3] inp_shape = x.get_shape().as_list()
ch = inp_shape[3]
assert ch is not None
shape = int(shape) shape = int(shape)
filter_shape = 2 * shape filter_shape = 2 * shape
...@@ -144,10 +146,15 @@ def BilinearUpSample(x, shape): ...@@ -144,10 +146,15 @@ def BilinearUpSample(x, shape):
w = np.repeat(w, ch * ch).reshape((filter_shape, filter_shape, ch, ch)) w = np.repeat(w, ch * ch).reshape((filter_shape, filter_shape, ch, ch))
weight_var = tf.constant(w, tf.float32, weight_var = tf.constant(w, tf.float32,
shape=(filter_shape, filter_shape, ch, ch)) shape=(filter_shape, filter_shape, ch, ch))
return tf.nn.conv2d_transpose(x, weight_var, deconv = tf.nn.conv2d_transpose(x, weight_var,
tf.shape(x) * tf.constant([1, shape, shape, 1], tf.int32), tf.shape(x) * tf.constant([1, shape, shape, 1], tf.int32),
[1,shape,shape,1], 'SAME') [1,shape,shape,1], 'SAME')
if inp_shape[1]: inp_shape[1] *= shape
if inp_shape[2]: inp_shape[2] *= shape
deconv.set_shape(inp_shape)
return deconv
from ._test import TestModel from ._test import TestModel
class TestPool(TestModel): class TestPool(TestModel):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment