Commit f7a79d48 authored by Yuxin Wu's avatar Yuxin Wu

Add shape test for Conv2DTranspose

parent e8e8b014
......@@ -7,6 +7,7 @@ import unittest
import tensorflow as tf
import numpy as np
from .conv2d import Conv2DTranspose
from .pool import FixedUnPooling
......@@ -32,9 +33,9 @@ class TestPool(TestModel):
h, w = 3, 4
scale = 2
mat = np.random.rand(h, w, 3).astype('float32')
inp = self.make_variable(mat)
inp = tf.reshape(inp, [1, h, w, 3])
output = FixedUnPooling('unpool', inp, scale)
input = self.make_variable(mat)
input = tf.reshape(input, [1, h, w, 3])
output = FixedUnPooling('unpool', input, scale)
res = self.run_variable(output)
self.assertEqual(res.shape, (1, scale * h, scale * w, 3))
......@@ -68,16 +69,30 @@ class TestPool(TestModel):
# self.assertTrue(diff.max() < 1e-4, diff.max())
class TestConv2DTranspose(TestModel):
def setUp(self):
tf.reset_default_graph()
def test_shape_match(self):
h, w = 12, 18
input = self.make_variable(np.random.rand(1, h, w, 3).astype("float32"))
for padding in ["same"]:
for stride in [1, 2]:
output = Conv2DTranspose(
'deconv_s{}_pad{}'.format(stride, padding),
input, 20, 3, strides=stride, padding=padding)
static_shape = output.shape
dynamic_shape = self.run_variable(output).shape
self.assertTrue(static_shape == dynamic_shape)
def run_test_case(case):
suite = unittest.TestLoader().loadTestsFromTestCase(case)
unittest.TextTestRunner(verbosity=2).run(suite)
if __name__ == '__main__':
import tensorpack
from tensorpack.utils import logger
from . import * # noqa
logger.setLevel(logging.CRITICAL)
subs = tensorpack.models._test.TestModel.__subclasses__()
for cls in subs:
run_test_case(cls)
unittest.main()
......@@ -206,8 +206,8 @@ class NVMLContext(object):
if __name__ == '__main__':
with NVMLContext() as ctx:
print(ctx.devices())
print(ctx.devices()[0].utilization())
for idx, dev in enumerate(ctx.devices()):
print(idx, dev.name())
with NVMLContext() as ctx:
print(ctx.devices())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment