Implement generateScalarImplementation for linalg_ext.fft op. (#6600)

This is a step toward https://github.com/google/iree/issues/6477
diff --git a/iree/compiler/Dialect/LinalgExt/IR/BUILD b/iree/compiler/Dialect/LinalgExt/IR/BUILD
index 45921ba..9280bed 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/BUILD
+++ b/iree/compiler/Dialect/LinalgExt/IR/BUILD
@@ -59,6 +59,7 @@
         "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgOps",
+        "@llvm-project//mlir:MathDialect",
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:Parser",
         "@llvm-project//mlir:SCFDialect",
diff --git a/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt b/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
index c70e358..a6bd9d1 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
@@ -33,6 +33,7 @@
     MLIRControlFlowInterfaces
     MLIRIR
     MLIRLinalg
+    MLIRMath
     MLIRMemRef
     MLIRParser
     MLIRSCF
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index c42df98..9eaaf67 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -12,6 +12,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/SMLoc.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -559,6 +560,100 @@
   return res;
 }
 
+// Generates FFT stage scalar implementation. This follows Cooley–Tukey FFT
+// algorithm. The pseudo reference code is:
+//   let s <- stage of linalg_ext.fft
+//   int m = 1 << s;
+//   int mh = m >> 1;
+//   for (int k = 0; k < n; k += m) {
+//     for (int j = 0; j < mh; ++j) {
+//       cplx w = exp(-2 * PI * j / m * I);
+//       cplx t = w * a[k + j + mh];
+//       cplx u = a[k + j];
+//       a[k + j] = u + t;
+//       a[k + j + mh] = u - t;
+//     }
+//   }
+LogicalResult FftOp::generateScalarImplementation(OpBuilder &b, Location loc,
+                                                  ValueRange ivs) {
+  Value real = getReal();
+  Value imag = getImag();
+  Value stage = getStage();
+  Value one = b.create<ConstantIndexOp>(loc, 1);
+  Value wholeSize = b.create<ShiftLeftOp>(loc, one, stage);
+  Value halfSize = b.create<SignedShiftRightOp>(loc, wholeSize, one);
+
+  auto rank = getOperandRank();
+  SmallVector<Value> operands;
+  SmallVector<OpFoldResult> lhsIvs(ivs.begin(), ivs.end());
+  SmallVector<OpFoldResult> ones(rank, b.getIndexAttr(1));
+  SmallVector<OpFoldResult> sizes(rank, b.getIndexAttr(1));
+  sizes.back() = halfSize;
+  operands.push_back(
+      b.create<memref::SubViewOp>(loc, real, lhsIvs, sizes, ones));
+  operands.push_back(
+      b.create<memref::SubViewOp>(loc, imag, lhsIvs, sizes, ones));
+
+  SmallVector<OpFoldResult> rhsIvs(ivs.begin(), ivs.end());
+  rhsIvs.back() = b.create<AddIOp>(loc, ivs.back(), halfSize).getResult();
+  operands.push_back(
+      b.create<memref::SubViewOp>(loc, real, rhsIvs, sizes, ones));
+  operands.push_back(
+      b.create<memref::SubViewOp>(loc, imag, rhsIvs, sizes, ones));
+
+  SmallVector<AffineMap> maps(operands.size(), b.getMultiDimIdentityMap(rank));
+  // TODO(hanchung): Use getLoopIteratorTypes(), once tiling method is
+  // implemented.
+  SmallVector<StringRef> iterTypes(rank, getParallelIteratorTypeName());
+
+  auto f32Type = b.getF32Type();
+  auto indexToF32 = [](OpBuilder &builder, Location loc, Value v) -> Value {
+    v = builder.create<IndexCastOp>(loc, builder.getI32Type(), v);
+    return builder.create<SIToFPOp>(loc, builder.getF32Type(), v);
+  };
+
+  // We will need exp(-2 * PI * j / m * I), compute "-2 * PI / m" for imag part
+  // first.
+  Value coeff = b.create<ConstantFloatOp>(
+      loc, llvm::APFloat(static_cast<float>(-2 * acos(-1))), f32Type);
+  coeff = b.create<DivFOp>(loc, coeff, indexToF32(b, loc, wholeSize));
+
+  b.create<linalg::GenericOp>(
+      loc, TypeRange{}, ValueRange{}, operands, maps, iterTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value lhsReal = args[0];
+        Value lhsImag = args[1];
+        Value rhsReal = args[2];
+        Value rhsImag = args[3];
+
+        // Compute "-2 * PI / m * j"
+        Value w = b.create<MulFOp>(
+            loc, coeff,
+            indexToF32(b, loc, b.create<linalg::IndexOp>(loc, rank - 1)));
+        Value wReal = b.create<math::CosOp>(loc, w);
+        Value wImag = b.create<math::SinOp>(loc, w);
+
+        // t = w * a[k + j + mh];
+        // ->  (x + yi)(u + vi) = (xu - yv) + (xv + yu)i
+        Value xu = b.create<MulFOp>(loc, wReal, rhsReal);
+        Value yv = b.create<MulFOp>(loc, wImag, rhsImag);
+        Value xv = b.create<MulFOp>(loc, wReal, rhsImag);
+        Value yu = b.create<MulFOp>(loc, wImag, rhsReal);
+        Value tReal = b.create<SubFOp>(loc, xu, yv);
+        Value tImag = b.create<AddFOp>(loc, xv, yu);
+
+        // cplx u = a[k + j];
+        // a[k + j] = u + t;
+        // a[k + j + mh] = u - t;
+        Value r1 = b.create<AddFOp>(loc, lhsReal, tReal);
+        Value r2 = b.create<AddFOp>(loc, lhsImag, tImag);
+        Value r3 = b.create<SubFOp>(loc, lhsReal, tReal);
+        Value r4 = b.create<SubFOp>(loc, lhsImag, tImag);
+        b.create<linalg::YieldOp>(loc, ValueRange{r1, r2, r3, r4});
+      });
+  return success();
+}
+
 #define DEFINE_OP_GET_EFFECTS(OP_NAME)                                    \
   void OP_NAME::getEffects(                                               \
       SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 89c7292..7bb2a4f 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -179,7 +179,8 @@
 }
 
 def LinalgExt_FftOp : LinalgExt_Op<"fft",
-    [DeclareOpInterfaceMethods<TiledOpInterface, []>]> {
+    [DeclareOpInterfaceMethods<TiledOpInterface,
+        ["generateScalarImplementation"]>]> {
   let summary = "Fft operator";
   let description = [{
     Apply 1D FFT to innermost dim. This is an iterative FFT, not recurrsive.
@@ -218,6 +219,8 @@
       return getOperandShape().back();
     }
     Value getStage() { return inputs()[0]; }
+    Value getReal() { return outputs()[0]; }
+    Value getImag() { return outputs()[1]; }
   }];
 }
 
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/BUILD b/iree/compiler/Dialect/LinalgExt/Transforms/BUILD
index 128b12e..5272e4d 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/BUILD
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/BUILD
@@ -49,6 +49,7 @@
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgOps",
         "@llvm-project//mlir:LinalgTransforms",
+        "@llvm-project//mlir:MathDialect",
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SCFDialect",
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt b/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
index e0edb4c..bb4df82 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@
     MLIRIR
     MLIRLinalg
     MLIRLinalgTransforms
+    MLIRMath
     MLIRMemRef
     MLIRPass
     MLIRSCF
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp b/iree/compiler/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp
index 4c994d4..a3da0b7 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp
@@ -12,6 +12,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -90,8 +91,9 @@
 struct LinalgExtToLoopsPass
     : public LinalgExtToLoopsBase<LinalgExtToLoopsPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<linalg::LinalgDialect, StandardOpsDialect,
-                    memref::MemRefDialect, scf::SCFDialect>();
+    registry
+        .insert<linalg::LinalgDialect, StandardOpsDialect, math::MathDialect,
+                memref::MemRefDialect, scf::SCFDialect>();
   }
 
   void runOnOperation() override {
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir b/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
index 52282d4..442456d 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
@@ -325,3 +325,103 @@
 // CHECK:             %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
 // CHECK:             %[[INDEX:.+]] = index_cast %[[INDEXVAL]] : i32 to index
 // CHECK:             memref.store %[[UPDATEVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]
+
+// -----
+
+func @fft_1D(%real: memref<16xf32>, %imag: memref<16xf32>) {
+  %stage = constant 1 : index
+  linalg_ext.fft
+    ins(%stage: index)
+    outs(%real, %imag: memref<16xf32>, memref<16xf32>)
+  return
+}
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK:       func @fft_1D
+// CHECK-SAME:    %[[REAL:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[IMAG:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[C0:.+]] = constant 0 : index
+// CHECK-DAG:     %[[C1:.+]] = constant 1 : index
+// CHECK-DAG:     %[[C16:.+]] = constant 16 : index
+// CHECK-DAG:     %[[SCALE:.+]] = constant -6.28318548 : f32
+// CHECK-DAG:     %[[NODE_RNG:.+]] = shift_left %[[C1]], %[[C1]] : index
+// CHECK:         scf.for %[[K:.+]] = %[[C0]] to %[[C16]] step %[[NODE_RNG]]
+// CHECK-DAG:       %[[M:.+]] = shift_left %[[C1]], %[[C1]] : index
+// CHECK-DAG:       %[[HM:.+]] = shift_right_signed %[[M]], %[[C1]] : index
+// CHECK:           %[[L_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[K]]] [%[[HM]]] [1]
+// CHECK:           %[[L_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[K]]] [%[[HM]]] [1]
+// CHECK:           %[[R_OFFSET:.+]] = addi %[[K]], %[[HM]] : index
+// CHECK:           %[[R_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[R_OFFSET]]] [%[[HM]]] [1]
+// CHECK:           %[[R_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[R_OFFSET]]] [%[[HM]]] [1]
+// CHECK:           %[[M_I32:.+]] = index_cast %[[M]] : index to i32
+// CHECK:           %[[M_F32:.+]] = sitofp %[[M_I32]] : i32 to f32
+// CHECK:           %[[COEFF:.+]] = divf %[[SCALE]], %[[M_F32]]
+// CHECK:           linalg.generic
+// CHECK-SAME:        indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP1]], #[[MAP1]]]
+// CHECK-SAME:        iterator_types = ["parallel"]
+// CHECK-SAME:        outs(%[[L_REAL_SLICE]], %[[L_IMAG_SLICE]], %[[R_REAL_SLICE]], %[[R_IMAG_SLICE]]
+// CHECK:           ^bb0(%[[L_REAL:.+]]: f32, %[[L_IMAG:.+]]: f32, %[[R_REAL:.+]]: f32, %[[R_IMAG:.+]]: f32)
+//
+//                    Compute exp coeff.
+// CHECK:             %[[J_IDX:.+]] = linalg.index 0 : index
+// CHECK:             %[[J_I32:.+]] = index_cast %[[J_IDX]] : index to i32
+// CHECK:             %[[J_F32:.+]] = sitofp %[[J_I32]] : i32 to f32
+// CHECK:             %[[EXP_COEF:.+]] = mulf %[[COEFF]], %[[J_F32]] : f32
+// CHECK:             %[[W_REAL:.+]] = math.cos %[[EXP_COEF]]
+// CHECK:             %[[W_IMAG:.+]] = math.sin %[[EXP_COEF]]
+//
+//                    Compute "t = w * a[k + j + mh]" by expanding
+//                      (x + yi)(u + vi) = (xu - yv) + (xv + yu)i
+// CHECK-DAG:         %[[XU:.+]] = mulf %[[W_REAL]], %[[R_REAL]]
+// CHECK-DAG:         %[[YV:.+]] = mulf %[[W_IMAG]], %[[R_IMAG]]
+// CHECK-DAG:         %[[XV:.+]] = mulf %[[W_REAL]], %[[R_IMAG]]
+// CHECK-DAG:         %[[YU:.+]] = mulf %[[W_IMAG]], %[[R_REAL]]
+// CHECK:             %[[T_REAL:.+]] = subf %[[XU]], %[[YV]]
+// CHECK:             %[[T_IMAG:.+]] = addf %[[XV]], %[[YU]]
+//
+//                    Compute the results.
+//                      u = a[k + j];
+//                      a[k + j] = u + t;
+//                      a[k + j + mh] = u - t;
+// CHECK:             %[[RES1:.+]] = addf %[[L_REAL]], %[[T_REAL]]
+// CHECK:             %[[RES2:.+]] = addf %[[L_IMAG]], %[[T_IMAG]]
+// CHECK:             %[[RES3:.+]] = subf %[[L_REAL]], %[[T_REAL]]
+// CHECK:             %[[RES4:.+]] = subf %[[L_IMAG]], %[[T_IMAG]]
+// CHECK:             linalg.yield %[[RES1]], %[[RES2]], %[[RES3]], %[[RES4]]
+
+// -----
+
+func @fft_2D(%real: memref<?x16xf32>, %imag: memref<?x16xf32>) {
+  %stage = constant 2 : index
+  linalg_ext.fft
+    ins(%stage: index)
+    outs(%real, %imag: memref<?x16xf32>, memref<?x16xf32>)
+  return
+}
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 16 + s0 + d1)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK:       func @fft_2D
+// CHECK-SAME:    %[[REAL:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[IMAG:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[C0:.+]] = constant 0 : index
+// CHECK-DAG:     %[[C1:.+]] = constant 1 : index
+// CHECK-DAG:     %[[C2:.+]] = constant 2 : index
+// CHECK-DAG:     %[[D0:.+]] = memref.dim %[[REAL]], %[[C0]] : memref<?x16xf32>
+// CHECK-DAG:     %[[NODE_RNG:.+]] = shift_left %[[C1]], %[[C2]] : index
+// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C1]]
+// CHECK:           scf.for %[[K:.+]] = %[[C0]] to %[[C16]] step %[[NODE_RNG]]
+// CHECK-DAG:         %[[M:.+]] = shift_left %[[C1]], %[[C2]] : index
+// CHECK-DAG:         %[[HM:.+]] = shift_right_signed %[[M]], %[[C1]] : index
+// CHECK:             %[[L_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[I]], %[[K]]] [1, %[[HM]]] [1, 1]
+// CHECK:             %[[L_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[I]], %[[K]]] [1, %[[HM]]] [1, 1]
+// CHECK:             %[[R_OFFSET:.+]] = addi %[[K]], %[[HM]] : index
+// CHECK:             %[[R_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[I]], %[[R_OFFSET]]] [1, %[[HM]]] [1, 1]
+// CHECK:             %[[R_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[I]], %[[R_OFFSET]]] [1, %[[HM]]] [1, 1]
+// CHECK:             linalg.generic
+// CHECK-SAME:          indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP1]], #[[MAP1]]]
+// CHECK-SAME:          iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:          outs(%[[L_REAL_SLICE]], %[[L_IMAG_SLICE]], %[[R_REAL_SLICE]], %[[R_IMAG_SLICE]]
+//
+//                    The computation is basically the same as in fft_1D and is
+//                    checked above; here we only check the differing part.
+// CHECK:             %{{.+}} = linalg.index 1 : index