Commit f7a79d48 authored by Yuxin Wu's avatar Yuxin Wu

Add shape test for Conv2DTranspose

parent e8e8b014
...@@ -7,6 +7,7 @@ import unittest ...@@ -7,6 +7,7 @@ import unittest
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
from .conv2d import Conv2DTranspose
from .pool import FixedUnPooling from .pool import FixedUnPooling
...@@ -32,9 +33,9 @@ class TestPool(TestModel): ...@@ -32,9 +33,9 @@ class TestPool(TestModel):
h, w = 3, 4 h, w = 3, 4
scale = 2 scale = 2
mat = np.random.rand(h, w, 3).astype('float32') mat = np.random.rand(h, w, 3).astype('float32')
inp = self.make_variable(mat) input = self.make_variable(mat)
inp = tf.reshape(inp, [1, h, w, 3]) input = tf.reshape(input, [1, h, w, 3])
output = FixedUnPooling('unpool', inp, scale) output = FixedUnPooling('unpool', input, scale)
res = self.run_variable(output) res = self.run_variable(output)
self.assertEqual(res.shape, (1, scale * h, scale * w, 3)) self.assertEqual(res.shape, (1, scale * h, scale * w, 3))
...@@ -68,16 +69,30 @@ class TestPool(TestModel): ...@@ -68,16 +69,30 @@ class TestPool(TestModel):
# self.assertTrue(diff.max() < 1e-4, diff.max()) # self.assertTrue(diff.max() < 1e-4, diff.max())
class TestConv2DTranspose(TestModel):
def setUp(self):
tf.reset_default_graph()
def test_shape_match(self):
h, w = 12, 18
input = self.make_variable(np.random.rand(1, h, w, 3).astype("float32"))
for padding in ["same"]:
for stride in [1, 2]:
output = Conv2DTranspose(
'deconv_s{}_pad{}'.format(stride, padding),
input, 20, 3, strides=stride, padding=padding)
static_shape = output.shape
dynamic_shape = self.run_variable(output).shape
self.assertTrue(static_shape == dynamic_shape)
def run_test_case(case): def run_test_case(case):
suite = unittest.TestLoader().loadTestsFromTestCase(case) suite = unittest.TestLoader().loadTestsFromTestCase(case)
unittest.TextTestRunner(verbosity=2).run(suite) unittest.TextTestRunner(verbosity=2).run(suite)
if __name__ == '__main__': if __name__ == '__main__':
import tensorpack
from tensorpack.utils import logger from tensorpack.utils import logger
from . import * # noqa
logger.setLevel(logging.CRITICAL) logger.setLevel(logging.CRITICAL)
subs = tensorpack.models._test.TestModel.__subclasses__() unittest.main()
for cls in subs:
run_test_case(cls)
...@@ -206,8 +206,8 @@ class NVMLContext(object): ...@@ -206,8 +206,8 @@ class NVMLContext(object):
if __name__ == '__main__': if __name__ == '__main__':
with NVMLContext() as ctx: with NVMLContext() as ctx:
print(ctx.devices()) for idx, dev in enumerate(ctx.devices()):
print(ctx.devices()[0].utilization()) print(idx, dev.name())
with NVMLContext() as ctx: with NVMLContext() as ctx:
print(ctx.devices()) print(ctx.devices())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment