Jakub Kuderski | 6019731 | 2023-05-25 16:56:59 -0400 | [diff] [blame] | 1 | // TODO(hanchung): Add other types of fft tests, e.g. fft, ifft, irfft. |
| 2 | |
// Checks a real-to-complex FFT (type = RFFT, length = [32]) on a rank-1
// tensor: 32 f32 samples go in, 17 complex<f32> bins come out, and the real
// and imaginary parts are each compared against expected constants.
func.func @rfft_1d() {
  // The 32 samples are one 16-value sequence written out twice.
  // NOTE(review): util.unfoldable_constant presumably keeps the compiler from
  // constant-folding the FFT away -- confirm against the util dialect docs.
  %signal = util.unfoldable_constant dense<[
      9.0, 1.0, 4.5, -0.3, 10.0, -1.0, 5.5, 0.3,
      299.0, 3.5, -0.777, 2.0, 1.7, 3.5, -4.5, 0.0,
      9.0, 1.0, 4.5, -0.3, 10.0, -1.0, 5.5, 0.3,
      299.0, 3.5, -0.777, 2.0, 1.7, 3.5, -4.5, 0.0
  ]> : tensor<32xf32>

  %spectrum = stablehlo.fft %signal, type = RFFT, length = [32]
      : (tensor<32xf32>) -> tensor<17xcomplex<f32>>

  // Split the complex bins so each component can be compared as plain f32.
  %re = stablehlo.real %spectrum : (tensor<17xcomplex<f32>>) -> tensor<17xf32>
  %im = stablehlo.imag %spectrum : (tensor<17xcomplex<f32>>) -> tensor<17xf32>

  // Every odd-indexed bin is expected to be 0.0 in both components.
  check.expect_almost_eq_const(%re, dense<[
      666.8460, 0.0, -590.16925, 0.0, 593.4485, 0.0, -579.52875, 0.0,
      629.95404, 0.0, -567.1126, 0.0, 591.75146, 0.0, -583.1894, 0.0,
      630.846
  ]> : tensor<17xf32>) : tensor<17xf32>
  check.expect_almost_eq_const(%im, dense<[
      0.0, 0.0, -23.956373, 0.0, -10.254326, 0.0, -6.1443653, 0.0,
      -10.0, 0.0, 3.865515, 0.0, 0.63767385, 0.0, 52.453506, 0.0,
      0.0
  ]> : tensor<17xf32>) : tensor<17xf32>
  return
}
| 15 | |
// Checks a real-to-complex FFT (type = RFFT, length = [32]) on a rank-2
// tensor: a 1x32 f32 input produces a 1x17 complex<f32> result, whose real
// and imaginary parts are each compared against expected constants.
func.func @rfft_2d() {
  // A single row of 32 samples: one 16-value sequence written out twice.
  // NOTE(review): util.unfoldable_constant presumably keeps the compiler from
  // constant-folding the FFT away -- confirm against the util dialect docs.
  %signal = util.unfoldable_constant dense<[[
      9.0, 1.0, 4.5, -0.3, 10.0, -1.0, 5.5, 0.3,
      299.0, 3.5, -0.777, 2.0, 1.7, 3.5, -4.5, 0.0,
      9.0, 1.0, 4.5, -0.3, 10.0, -1.0, 5.5, 0.3,
      299.0, 3.5, -0.777, 2.0, 1.7, 3.5, -4.5, 0.0
  ]]> : tensor<1x32xf32>

  // Only the size-32 dimension shrinks (to 17); the leading dimension of 1 is
  // carried through unchanged in the result type.
  %spectrum = stablehlo.fft %signal, type = RFFT, length = [32]
      : (tensor<1x32xf32>) -> tensor<1x17xcomplex<f32>>

  // Split the complex bins so each component can be compared as plain f32.
  %re = stablehlo.real %spectrum : (tensor<1x17xcomplex<f32>>) -> tensor<1x17xf32>
  %im = stablehlo.imag %spectrum : (tensor<1x17xcomplex<f32>>) -> tensor<1x17xf32>

  // Every odd-indexed bin is expected to be 0.0 in both components.
  check.expect_almost_eq_const(%re, dense<[[
      666.8460, 0.0, -590.16925, 0.0, 593.4485, 0.0, -579.52875, 0.0,
      629.95404, 0.0, -567.1126, 0.0, 591.75146, 0.0, -583.1894, 0.0,
      630.846
  ]]> : tensor<1x17xf32>) : tensor<1x17xf32>
  check.expect_almost_eq_const(%im, dense<[[
      0.0, 0.0, -23.956373, 0.0, -10.254326, 0.0, -6.1443653, 0.0,
      -10.0, 0.0, 3.865515, 0.0, 0.63767385, 0.0, 52.453506, 0.0,
      0.0
  ]]> : tensor<1x17xf32>) : tensor<1x17xf32>
  return
}