Commit f7e35411 authored by Yuxin Wu's avatar Yuxin Wu

add tests; remove nightly from CI

parent 6e93970f
......@@ -36,7 +36,8 @@ jobs:
max-parallel: 6
matrix:
python-version: [3.6]
TF-version: [1.3.0, 1.14.0, nightly]
# TF-version: [1.3.0, 1.14.0, nightly] # TODO make nightly work
TF-version: [1.3.0, 1.14.0]
steps:
- uses: actions/checkout@v1
- name: Set up Python ${{ matrix.python-version }}
......
......@@ -13,13 +13,13 @@ from .pool import FixedUnPooling
class TestModel(unittest.TestCase):
def run_variable(self, var):
def eval(self, x, feed_dict=None):
sess = tf.Session()
sess.run(tf.global_variables_initializer())
if isinstance(var, list):
return sess.run(var)
if isinstance(x, list):
return sess.run(x, feed_dict=feed_dict)
else:
return sess.run([var])[0]
return sess.run([x], feed_dict=feed_dict)[0]
def make_variable(self, *args):
if len(args) > 1:
......@@ -36,7 +36,7 @@ class TestPool(TestModel):
input = self.make_variable(mat)
input = tf.reshape(input, [1, h, w, 3])
output = FixedUnPooling('unpool', input, scale)
res = self.run_variable(output)
res = self.eval(output)
self.assertEqual(res.shape, (1, scale * h, scale * w, 3))
# mat is on corner
......@@ -56,7 +56,7 @@ class TestPool(TestModel):
# inp = tf.reshape(inp, [1, h, w, 1])
#
# output = BilinearUpSample(inp, scale)
# res = self.run_variable(output)[0, :, :, 0]
# res = self.eval(output)[0, :, :, 0]
#
# from skimage.transform import rescale
# res2 = rescale(mat, scale, mode='edge')
......@@ -83,9 +83,26 @@ class TestConv2DTranspose(TestModel):
input, 20, 3, strides=stride, padding=padding)
static_shape = output.shape
dynamic_shape = self.run_variable(output).shape
dynamic_shape = self.eval(output).shape
self.assertTrue(static_shape == dynamic_shape)
def test_unspecified_shape_match(self):
h, w = 12, 18
input = tf.placeholder(shape=(1, h, None, 3), dtype=tf.float32)
for padding in ["same", "valid"]:
for stride in [1, 2]:
output = Conv2DTranspose(
'deconv_s{}_pad{}'.format(stride, padding),
input, 20, 3, strides=stride, padding=padding)
static_shape = tuple(output.shape.as_list())
dynamic_shape = self.eval(
output,
feed_dict={input: np.random.rand(1, h, w, 3)}).shape
self.assertTrue(static_shape[2] is None)
self.assertTrue(static_shape[:2] == dynamic_shape[:2])
self.assertTrue(static_shape[3] == dynamic_shape[3])
def run_test_case(case):
suite = unittest.TestLoader().loadTestsFromTestCase(case)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment