blob: 44ec106dfe40694d8d0729e4579847118a24c5bd [file] [log] [blame]
// TODO(hanchung): Add tests for the other FFT types: FFT, IFFT, and IRFFT.
2
// Real-to-complex FFT (RFFT) of a single 32-element f32 signal, producing
// 32 / 2 + 1 = 17 frequency bins.
//
// The signal is one 16-element sequence written out twice. Because it repeats
// with period 16, every odd-indexed bin of the result is zero in both its real
// and imaginary parts. Bins 0 (DC) and 16 (Nyquist) are purely real.
func.func @rfft_1d() {
  %signal = util.unfoldable_constant dense<[
      9.0, 1.0, 4.5, -0.3, 10.0, -1.0, 5.5, 0.3,
      299.0, 3.5, -0.777, 2.0, 1.7, 3.5, -4.5, 0.0,
      9.0, 1.0, 4.5, -0.3, 10.0, -1.0, 5.5, 0.3,
      299.0, 3.5, -0.777, 2.0, 1.7, 3.5, -4.5, 0.0]> : tensor<32xf32>
  %spectrum = stablehlo.fft %signal, type = RFFT, length = [32] : (tensor<32xf32>) -> tensor<17xcomplex<f32>>
  // Check the real and imaginary components separately.
  %re = stablehlo.real %spectrum : (tensor<17xcomplex<f32>>) -> tensor<17xf32>
  %im = stablehlo.imag %spectrum : (tensor<17xcomplex<f32>>) -> tensor<17xf32>
  check.expect_almost_eq_const(%re, dense<[
      666.8460, 0.0, -590.16925, 0.0, 593.4485, 0.0, -579.52875, 0.0,
      629.95404, 0.0, -567.1126, 0.0, 591.75146, 0.0, -583.1894, 0.0,
      630.846]> : tensor<17xf32>) : tensor<17xf32>
  check.expect_almost_eq_const(%im, dense<[
      0.0, 0.0, -23.956373, 0.0, -10.254326, 0.0, -6.1443653, 0.0,
      -10.0, 0.0, 3.865515, 0.0, 0.63767385, 0.0, 52.453506, 0.0,
      0.0]> : tensor<17xf32>) : tensor<17xf32>
  return
}
15
// RFFT over the innermost dimension of a rank-2 input: each of the 2 rows of
// 32 f32 values becomes 32 / 2 + 1 = 17 complex bins.
//
// The batch has two rows so that per-row handling is actually exercised.
// A unit batch cannot catch bugs where only the first row is transformed or
// where rows leak into each other.
//   * Row 0 is the same period-16 signal as @rfft_1d, so its expected spectrum
//     matches that test (all odd bins are zero).
//   * Row 1 is a unit impulse at n = 0. Its DFT is exactly 1 + 0i in every bin,
//     so its expected values are exact.
func.func @rfft_2d() {
  %input = util.unfoldable_constant dense<[
      [9.0, 1.0, 4.5, -0.3, 10.0, -1.0, 5.5, 0.3,
       299.0, 3.5, -0.777, 2.0, 1.7, 3.5, -4.5, 0.0,
       9.0, 1.0, 4.5, -0.3, 10.0, -1.0, 5.5, 0.3,
       299.0, 3.5, -0.777, 2.0, 1.7, 3.5, -4.5, 0.0],
      [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
       0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
       0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
       0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]> : tensor<2x32xf32>
  %0 = stablehlo.fft %input, type = RFFT, length = [32] : (tensor<2x32xf32>) -> tensor<2x17xcomplex<f32>>
  %1 = stablehlo.real %0 : (tensor<2x17xcomplex<f32>>) -> tensor<2x17xf32>
  %2 = stablehlo.imag %0 : (tensor<2x17xcomplex<f32>>) -> tensor<2x17xf32>
  check.expect_almost_eq_const(%1, dense<[
      [666.8460, 0.0, -590.16925, 0.0, 593.4485, 0.0, -579.52875, 0.0,
       629.95404, 0.0, -567.1126, 0.0, 591.75146, 0.0, -583.1894, 0.0,
       630.846],
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0]]> : tensor<2x17xf32>) : tensor<2x17xf32>
  check.expect_almost_eq_const(%2, dense<[
      [0.0, 0.0, -23.956373, 0.0, -10.254326, 0.0, -6.1443653, 0.0,
       -10.0, 0.0, 3.865515, 0.0, 0.63767385, 0.0, 52.453506, 0.0,
       0.0],
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
       0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
       0.0]]> : tensor<2x17xf32>) : tensor<2x17xf32>
  return
}