Implement generateScalarImplementation for linalg_ext.fft op. (#6600)
This is a step toward https://github.com/google/iree/issues/6477
diff --git a/iree/compiler/Dialect/LinalgExt/IR/BUILD b/iree/compiler/Dialect/LinalgExt/IR/BUILD
index 45921ba..9280bed 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/BUILD
+++ b/iree/compiler/Dialect/LinalgExt/IR/BUILD
@@ -59,6 +59,7 @@
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
+ "@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:SCFDialect",
diff --git a/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt b/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
index c70e358..a6bd9d1 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
@@ -33,6 +33,7 @@
MLIRControlFlowInterfaces
MLIRIR
MLIRLinalg
+ MLIRMath
MLIRMemRef
MLIRParser
MLIRSCF
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index c42df98..9eaaf67 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -12,6 +12,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SMLoc.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -559,6 +560,100 @@
return res;
}
+// Generates FFT stage scalar implementation. This follows Cooley–Tukey FFT
+// algorithm. The pseudo reference code is:
+// let s <- stage of linalg_ext.fft
+// int m = 1 << s;
+// int mh = m >> 1;
+// for (int k = 0; k < n; k += m) {
+// for (int j = 0; j < mh; ++j) {
+// cplx w = exp(-2 * PI * j / m * I);
+// cplx t = w * a[k + j + mh];
+// cplx u = a[k + j];
+// a[k + j] = u + t;
+// a[k + j + mh] = u - t;
+// }
+// }
+LogicalResult FftOp::generateScalarImplementation(OpBuilder &b, Location loc,
+ ValueRange ivs) {
+ Value real = getReal();
+ Value imag = getImag();
+ Value stage = getStage();
+ Value one = b.create<ConstantIndexOp>(loc, 1);
+ Value wholeSize = b.create<ShiftLeftOp>(loc, one, stage);
+ Value halfSize = b.create<SignedShiftRightOp>(loc, wholeSize, one);
+
+ auto rank = getOperandRank();
+ SmallVector<Value> operands;
+ SmallVector<OpFoldResult> lhsIvs(ivs.begin(), ivs.end());
+ SmallVector<OpFoldResult> ones(rank, b.getIndexAttr(1));
+ SmallVector<OpFoldResult> sizes(rank, b.getIndexAttr(1));
+ sizes.back() = halfSize;
+ operands.push_back(
+ b.create<memref::SubViewOp>(loc, real, lhsIvs, sizes, ones));
+ operands.push_back(
+ b.create<memref::SubViewOp>(loc, imag, lhsIvs, sizes, ones));
+
+ SmallVector<OpFoldResult> rhsIvs(ivs.begin(), ivs.end());
+ rhsIvs.back() = b.create<AddIOp>(loc, ivs.back(), halfSize).getResult();
+ operands.push_back(
+ b.create<memref::SubViewOp>(loc, real, rhsIvs, sizes, ones));
+ operands.push_back(
+ b.create<memref::SubViewOp>(loc, imag, rhsIvs, sizes, ones));
+
+ SmallVector<AffineMap> maps(operands.size(), b.getMultiDimIdentityMap(rank));
+ // TODO(hanchung): Use getLoopIteratorTypes(), once tiling method is
+ // implemented.
+ SmallVector<StringRef> iterTypes(rank, getParallelIteratorTypeName());
+
+ auto f32Type = b.getF32Type();
+ auto indexToF32 = [](OpBuilder &builder, Location loc, Value v) -> Value {
+ v = builder.create<IndexCastOp>(loc, builder.getI32Type(), v);
+ return builder.create<SIToFPOp>(loc, builder.getF32Type(), v);
+ };
+
+ // We will need exp(-2 * PI * j / m * I), compute "-2 * PI / m" for imag part
+ // first.
+ Value coeff = b.create<ConstantFloatOp>(
+ loc, llvm::APFloat(static_cast<float>(-2 * acos(-1))), f32Type);
+ coeff = b.create<DivFOp>(loc, coeff, indexToF32(b, loc, wholeSize));
+
+ b.create<linalg::GenericOp>(
+ loc, TypeRange{}, ValueRange{}, operands, maps, iterTypes,
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ Value lhsReal = args[0];
+ Value lhsImag = args[1];
+ Value rhsReal = args[2];
+ Value rhsImag = args[3];
+
+ // Compute "-2 * PI / m * j"
+ Value w = b.create<MulFOp>(
+ loc, coeff,
+ indexToF32(b, loc, b.create<linalg::IndexOp>(loc, rank - 1)));
+ Value wReal = b.create<math::CosOp>(loc, w);
+ Value wImag = b.create<math::SinOp>(loc, w);
+
+ // t = w * a[k + j + mh];
+ // -> (x + yi)(u + vi) = (xu - yv) + (xv + yu)i
+ Value xu = b.create<MulFOp>(loc, wReal, rhsReal);
+ Value yv = b.create<MulFOp>(loc, wImag, rhsImag);
+ Value xv = b.create<MulFOp>(loc, wReal, rhsImag);
+ Value yu = b.create<MulFOp>(loc, wImag, rhsReal);
+ Value tReal = b.create<SubFOp>(loc, xu, yv);
+ Value tImag = b.create<AddFOp>(loc, xv, yu);
+
+ // cplx u = a[k + j];
+ // a[k + j] = u + t;
+ // a[k + j + mh] = u - t;
+ Value r1 = b.create<AddFOp>(loc, lhsReal, tReal);
+ Value r2 = b.create<AddFOp>(loc, lhsImag, tImag);
+ Value r3 = b.create<SubFOp>(loc, lhsReal, tReal);
+ Value r4 = b.create<SubFOp>(loc, lhsImag, tImag);
+ b.create<linalg::YieldOp>(loc, ValueRange{r1, r2, r3, r4});
+ });
+ return success();
+}
+
#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
void OP_NAME::getEffects( \
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 89c7292..7bb2a4f 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -179,7 +179,8 @@
}
def LinalgExt_FftOp : LinalgExt_Op<"fft",
- [DeclareOpInterfaceMethods<TiledOpInterface, []>]> {
+ [DeclareOpInterfaceMethods<TiledOpInterface,
+ ["generateScalarImplementation"]>]> {
let summary = "Fft operator";
let description = [{
Apply 1D FFT to innermost dim. This is an iterative FFT, not recurrsive.
@@ -218,6 +219,8 @@
return getOperandShape().back();
}
Value getStage() { return inputs()[0]; }
+ Value getReal() { return outputs()[0]; }
+ Value getImag() { return outputs()[1]; }
}];
}
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/BUILD b/iree/compiler/Dialect/LinalgExt/Transforms/BUILD
index 128b12e..5272e4d 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/BUILD
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/BUILD
@@ -49,6 +49,7 @@
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms",
+ "@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt b/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
index e0edb4c..bb4df82 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@
MLIRIR
MLIRLinalg
MLIRLinalgTransforms
+ MLIRMath
MLIRMemRef
MLIRPass
MLIRSCF
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp b/iree/compiler/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp
index 4c994d4..a3da0b7 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp
@@ -12,6 +12,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -90,8 +91,9 @@
struct LinalgExtToLoopsPass
: public LinalgExtToLoopsBase<LinalgExtToLoopsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<linalg::LinalgDialect, StandardOpsDialect,
- memref::MemRefDialect, scf::SCFDialect>();
+ registry
+ .insert<linalg::LinalgDialect, StandardOpsDialect, math::MathDialect,
+ memref::MemRefDialect, scf::SCFDialect>();
}
void runOnOperation() override {
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir b/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
index 52282d4..442456d 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
@@ -325,3 +325,103 @@
// CHECK: %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK: %[[INDEX:.+]] = index_cast %[[INDEXVAL]] : i32 to index
// CHECK: memref.store %[[UPDATEVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]
+
+// -----
+
+func @fft_1D(%real: memref<16xf32>, %imag: memref<16xf32>) {
+ %stage = constant 1 : index
+ linalg_ext.fft
+ ins(%stage: index)
+ outs(%real, %imag: memref<16xf32>, memref<16xf32>)
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: func @fft_1D
+// CHECK-SAME: %[[REAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[IMAG:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C16:.+]] = constant 16 : index
+// CHECK-DAG: %[[SCALE:.+]] = constant -6.28318548 : f32
+// CHECK-DAG: %[[NODE_RNG:.+]] = shift_left %[[C1]], %[[C1]] : index
+// CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C16]] step %[[NODE_RNG]]
+// CHECK-DAG: %[[M:.+]] = shift_left %[[C1]], %[[C1]] : index
+// CHECK-DAG: %[[HM:.+]] = shift_right_signed %[[M]], %[[C1]] : index
+// CHECK: %[[L_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[K]]] [%[[HM]]] [1]
+// CHECK: %[[L_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[K]]] [%[[HM]]] [1]
+// CHECK: %[[R_OFFSET:.+]] = addi %[[K]], %[[HM]] : index
+// CHECK: %[[R_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[R_OFFSET]]] [%[[HM]]] [1]
+// CHECK: %[[R_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[R_OFFSET]]] [%[[HM]]] [1]
+// CHECK: %[[M_I32:.+]] = index_cast %[[M]] : index to i32
+// CHECK: %[[M_F32:.+]] = sitofp %[[M_I32]] : i32 to f32
+// CHECK: %[[COEFF:.+]] = divf %[[SCALE]], %[[M_F32]]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP1]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel"]
+// CHECK-SAME: outs(%[[L_REAL_SLICE]], %[[L_IMAG_SLICE]], %[[R_REAL_SLICE]], %[[R_IMAG_SLICE]]
+// CHECK: ^bb0(%[[L_REAL:.+]]: f32, %[[L_IMAG:.+]]: f32, %[[R_REAL:.+]]: f32, %[[R_IMAG:.+]]: f32)
+//
+// Compute exp coeff.
+// CHECK: %[[J_IDX:.+]] = linalg.index 0 : index
+// CHECK: %[[J_I32:.+]] = index_cast %[[J_IDX]] : index to i32
+// CHECK: %[[J_F32:.+]] = sitofp %[[J_I32]] : i32 to f32
+// CHECK: %[[EXP_COEF:.+]] = mulf %[[COEFF]], %[[J_F32]] : f32
+// CHECK: %[[W_REAL:.+]] = math.cos %[[EXP_COEF]]
+// CHECK: %[[W_IMAG:.+]] = math.sin %[[EXP_COEF]]
+//
+// Compute "t = w * a[k + j + mh]" by expanding
+// (x + yi)(u + vi) = (xu - yv) + (xv + yu)i
+// CHECK-DAG: %[[XU:.+]] = mulf %[[W_REAL]], %[[R_REAL]]
+// CHECK-DAG: %[[YV:.+]] = mulf %[[W_IMAG]], %[[R_IMAG]]
+// CHECK-DAG: %[[XV:.+]] = mulf %[[W_REAL]], %[[R_IMAG]]
+// CHECK-DAG: %[[YU:.+]] = mulf %[[W_IMAG]], %[[R_REAL]]
+// CHECK: %[[T_REAL:.+]] = subf %[[XU]], %[[YV]]
+// CHECK: %[[T_IMAG:.+]] = addf %[[XV]], %[[YU]]
+//
+// Compute the results.
+// u = a[k + j];
+// a[k + j] = u + t;
+// a[k + j + mh] = u - t;
+// CHECK: %[[RES1:.+]] = addf %[[L_REAL]], %[[T_REAL]]
+// CHECK: %[[RES2:.+]] = addf %[[L_IMAG]], %[[T_IMAG]]
+// CHECK: %[[RES3:.+]] = subf %[[L_REAL]], %[[T_REAL]]
+// CHECK: %[[RES4:.+]] = subf %[[L_IMAG]], %[[T_IMAG]]
+// CHECK: linalg.yield %[[RES1]], %[[RES2]], %[[RES3]], %[[RES4]]
+
+// -----
+
+func @fft_2D(%real: memref<?x16xf32>, %imag: memref<?x16xf32>) {
+ %stage = constant 2 : index
+ linalg_ext.fft
+ ins(%stage: index)
+ outs(%real, %imag: memref<?x16xf32>, memref<?x16xf32>)
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 16 + s0 + d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: func @fft_2D
+// CHECK-SAME: %[[REAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[IMAG:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[REAL]], %[[C0]] : memref<?x16xf32>
+// CHECK-DAG: %[[NODE_RNG:.+]] = shift_left %[[C1]], %[[C2]] : index
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C1]]
+// CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C16]] step %[[NODE_RNG]]
+// CHECK-DAG: %[[M:.+]] = shift_left %[[C1]], %[[C2]] : index
+// CHECK-DAG: %[[HM:.+]] = shift_right_signed %[[M]], %[[C1]] : index
+// CHECK: %[[L_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[I]], %[[K]]] [1, %[[HM]]] [1, 1]
+// CHECK: %[[L_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[I]], %[[K]]] [1, %[[HM]]] [1, 1]
+// CHECK: %[[R_OFFSET:.+]] = addi %[[K]], %[[HM]] : index
+// CHECK: %[[R_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[I]], %[[R_OFFSET]]] [1, %[[HM]]] [1, 1]
+// CHECK: %[[R_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[I]], %[[R_OFFSET]]] [1, %[[HM]]] [1, 1]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP1]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: outs(%[[L_REAL_SLICE]], %[[L_IMAG_SLICE]], %[[R_REAL_SLICE]], %[[R_IMAG_SLICE]]
+//
+// The computation is basically the same as in the 1-D case, which is
+// checked above; only the differing part is checked here.
+// CHECK: %{{.+}} = linalg.index 1 : index