From cc0f4f9e1f5a49619ae6e2bb4387512e732649ef Mon Sep 17 00:00:00 2001 From: cboss6 Date: Tue, 16 Apr 2024 09:24:58 +0800 Subject: [PATCH] [GPU][Fix] Add more shapes in softmax ut. (#2675) --- test/sanity/nn/softmax_op_test.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/test/sanity/nn/softmax_op_test.py b/test/sanity/nn/softmax_op_test.py index 0bef0a72f..d90f7e445 100644 --- a/test/sanity/nn/softmax_op_test.py +++ b/test/sanity/nn/softmax_op_test.py @@ -131,8 +131,12 @@ def _testOverflow(self, use_gpu=False): atol=1.e-5) def testFloat(self): - self._testAll( - np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32)) + features = [np.random.randn(256, 10480).astype(np.float32), + np.random.randn(16, 16).astype(np.float32), + np.random.randn(32*512, 32*512).astype(np.float32)] + for feature in features: + self._testAll(feature) + @unittest.skipUnless(test.is_built_with_gpu_support(), "Test only applicable when running on GPUs") @@ -146,8 +150,11 @@ def testFloatGPU(self): self._testAll(data.astype(np.float32)) def testHalf(self): - self._testAll( - np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16)) + features = [np.random.randn(256, 10480).astype(np.float16), + np.random.randn(16, 16).astype(np.float16), + np.random.randn(32*512, 32*512).astype(np.float16)] + for feature in features: + self._testAll(feature) @unittest.skipUnless(test.is_built_with_gpu_support(), "Test only applicable when running on GPUs") @@ -177,9 +184,11 @@ def testDoubleGPU(self): self._testAll(data.astype(np.float64)) def testBfloat16(self): - self._testAll( - np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32), - dtype=dtypes.bfloat16) + features = [np.random.randn(256, 10480).astype(np.float32), + np.random.randn(16, 16).astype(np.float32), + np.random.randn(32*512, 32*512).astype(np.float32)] + for feature in features: + self._testAll(feature, dtype=dtypes.bfloat16) @unittest.skipUnless(test.is_built_with_gpu_support(), "Test only applicable when running on GPUs")