Create VMLA_FftOp and generate the lowering from VMLA_FftPseudoOp to it. (#3480)
(The actual Implementation for FftOp is still WIP)
diff --git a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
index b3982b0..6bd35bb 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
@@ -43,6 +43,7 @@
// If we end up with a lot of these, consider using an "is pseudo" trait.
addIllegalOp<IREE::VMLA::BatchMatMulPseudoOp>();
addIllegalOp<IREE::VMLA::SortPseudoOp>();
+ addIllegalOp<IREE::VMLA::FftPseudoOp>();
// Allow other ops to pass through so long as their type is valid (not a
// tensor, basically).
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 6f43d6e..5b842a5 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -703,6 +703,43 @@
TypeConverter &typeConverter;
};
+struct FftOpConversion : public OpConversionPattern<IREE::VMLA::FftPseudoOp> {
+ FftOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+ : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+ LogicalResult matchAndRewrite(
+ IREE::VMLA::FftPseudoOp srcOp, ArrayRef<Value> rawOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto input_shape = VMLAConversionTarget::getTensorShape(
+ srcOp.getLoc(), srcOp.real_in(), typeConverter, rewriter);
+
+ auto real_input_type = srcOp.getOperand(0).getType().cast<ShapedType>();
+ auto imag_input_type = srcOp.getOperand(1).getType().cast<ShapedType>();
+
+ // The input type/shape should match for the real and imag components.
+ if (real_input_type != imag_input_type) {
+ srcOp.emitWarning() << "real and imag should have matching types";
+ return failure();
+ }
+
+ auto real_out = VMLAConversionTarget::allocateOutputBuffer(
+ srcOp.getLoc(), srcOp.getResult(0), typeConverter, rewriter);
+ auto imag_out = VMLAConversionTarget::allocateOutputBuffer(
+ srcOp.getLoc(), srcOp.getResult(1), typeConverter, rewriter);
+
+ rewriter.createOrFold<IREE::VMLA::FftOp>(
+ srcOp.getLoc(), rawOperands[0], input_shape, rawOperands[1],
+ input_shape, real_out, imag_out,
+ TypeAttr::get(real_input_type.getElementType()),
+ TypeAttr::get(imag_input_type.getElementType()));
+
+ rewriter.replaceOp(srcOp, {real_out, imag_out});
+ return success();
+ }
+
+ TypeConverter &typeConverter;
+};
+
struct ConvertOpConversion : public OpConversionPattern<mhlo::ConvertOp> {
using OpConversionPattern::OpConversionPattern;
@@ -769,6 +806,9 @@
// vmla.sort.pseudo
patterns.insert<SortOpConversion>(context, typeConverter);
+ // vmla.fft.pseudo
+ patterns.insert<FftOpConversion>(context, typeConverter);
+
// Simple 1:1 conversion patterns using the automated trait-based converter.
// Used for HLO ops that have equivalent VMLA ops such as most arithmetic ops.
patterns.insert<VMLAOpConversion<mhlo::AddOp, IREE::VMLA::AddOp>>(
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir
new file mode 100644
index 0000000..1e3c365
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir
@@ -0,0 +1,11 @@
+// RUN: iree-opt -split-input-file -iree-vmla-pre-conversion-lowering -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+func @fft(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>) attributes { sym_visibility = "private" } {
+ // CHECK: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]>
+ // CHECK-NEXT: [[C32:%.+]] = constant 32 : index
+ // CHECK-NEXT: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
+ // CHECK-NEXT: [[OUTBUF2:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
+ // CHECK-NEXT: vmla.fft %arg0([[RS]] : !shapex.ranked_shape<[8]>), %arg1([[RS]] : !shapex.ranked_shape<[8]>), out [[OUTBUF1]], [[OUTBUF2]] : f32, f32
+ %real, %imag = "vmla.fft.pseudo"(%arg0, %arg1) : (tensor<8xf32>, tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>)
+ return %real, %imag : tensor<8xf32>, tensor<8xf32>
+}
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index f31545d..f139690 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -694,5 +694,24 @@
}];
}
+def VMLA_FftOp : VMLA_ElementTypeOp<"fft", [VMLA_IncludeShapes]> {
+ let arguments = (ins
+ VMLA_Buffer:$real_in,
+ VMLA_Shape:$real_in_shape,
+ VMLA_Buffer:$imag_in,
+ VMLA_Shape:$imag_in_shape,
+ VMLA_Buffer:$real_out,
+ VMLA_Buffer:$imag_out,
+ VMLA_AnyTypeAttr:$real_element_type,
+ VMLA_AnyTypeAttr:$imag_element_type
+ );
+
+ let assemblyFormat = [{
+ $real_in`(`$real_in_shape `:` type($real_in_shape)`)` `,`
+ $imag_in`(`$imag_in_shape `:` type($imag_in_shape)`)` `,`
+ `out` $real_out `,` $imag_out attr-dict `:` $real_element_type `,` $imag_element_type
+ }];
+}
+
#endif // IREE_DIALECT_VMLA_OPS