Create VMLA_FftOp and generate the lowering from VMLA_FftPseudoOp to it. (#3480)

(The actual implementation of FftOp is still a work in progress.)
diff --git a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
index b3982b0..6bd35bb 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
@@ -43,6 +43,7 @@
   // If we end up with a lot of these, consider using an "is pseudo" trait.
   addIllegalOp<IREE::VMLA::BatchMatMulPseudoOp>();
   addIllegalOp<IREE::VMLA::SortPseudoOp>();
+  addIllegalOp<IREE::VMLA::FftPseudoOp>();
 
   // Allow other ops to pass through so long as their type is valid (not a
   // tensor, basically).
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 6f43d6e..5b842a5 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -703,6 +703,43 @@
   TypeConverter &typeConverter;
 };
 
+// Lowers vmla.fft.pseudo (tensor domain) into a vmla.fft op on buffers.
+struct FftOpConversion : public OpConversionPattern<IREE::VMLA::FftPseudoOp> {
+  FftOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+      : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+  LogicalResult matchAndRewrite(
+      IREE::VMLA::FftPseudoOp srcOp, ArrayRef<Value> rawOperands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto real_input_type = srcOp.getOperand(0).getType().cast<ShapedType>();
+    auto imag_input_type = srcOp.getOperand(1).getType().cast<ShapedType>();
+    // The input type/shape should match for the real and imag components.
+    if (real_input_type != imag_input_type) {
+      return rewriter.notifyMatchFailure(
+          srcOp, "real and imag should have matching types");
+    }
+
+    // Identical input types imply one shape value serves both operands.
+    auto input_shape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.real_in(), typeConverter, rewriter);
+    auto real_out = VMLAConversionTarget::allocateOutputBuffer(
+        srcOp.getLoc(), srcOp.getResult(0), typeConverter, rewriter);
+    auto imag_out = VMLAConversionTarget::allocateOutputBuffer(
+        srcOp.getLoc(), srcOp.getResult(1), typeConverter, rewriter);
+
+    rewriter.createOrFold<IREE::VMLA::FftOp>(
+        srcOp.getLoc(), rawOperands[0], input_shape, rawOperands[1],
+        input_shape, real_out, imag_out,
+        TypeAttr::get(real_input_type.getElementType()),
+        TypeAttr::get(imag_input_type.getElementType()));
+
+    rewriter.replaceOp(srcOp, {real_out, imag_out});
+    return success();
+  }
+
+  TypeConverter &typeConverter;
+};
+
 struct ConvertOpConversion : public OpConversionPattern<mhlo::ConvertOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -769,6 +806,9 @@
   // vmla.sort.pseudo
   patterns.insert<SortOpConversion>(context, typeConverter);
 
+  // vmla.fft.pseudo
+  patterns.insert<FftOpConversion>(context, typeConverter);
+
   // Simple 1:1 conversion patterns using the automated trait-based converter.
   // Used for HLO ops that have equivalent VMLA ops such as most arithmetic ops.
   patterns.insert<VMLAOpConversion<mhlo::AddOp, IREE::VMLA::AddOp>>(
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir
new file mode 100644
index 0000000..1e3c365
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir
@@ -0,0 +1,11 @@
+// RUN: iree-opt -split-input-file -iree-vmla-pre-conversion-lowering -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+func @fft(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>) attributes { sym_visibility = "private" } {
+  // CHECK: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]>
+  // CHECK-NEXT: [[C32:%.+]] = constant 32 : index
+  // CHECK-NEXT: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
+  // CHECK-NEXT: [[OUTBUF2:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
+  // CHECK-NEXT: vmla.fft %arg0([[RS]] : !shapex.ranked_shape<[8]>), %arg1([[RS]] : !shapex.ranked_shape<[8]>), out [[OUTBUF1]], [[OUTBUF2]] : f32, f32
+  %real, %imag = "vmla.fft.pseudo"(%arg0, %arg1) : (tensor<8xf32>, tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>)
+  return %real, %imag : tensor<8xf32>, tensor<8xf32>
+}
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index f31545d..f139690 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -694,5 +694,24 @@
   }];
 }
 
+// FFT over a complex tensor stored as separate real/imaginary buffers.
+def VMLA_FftOp : VMLA_ElementTypeOp<"fft", [VMLA_IncludeShapes]> {
+  let arguments = (ins
+    VMLA_Buffer:$real_in,
+    VMLA_Shape:$real_in_shape,
+    VMLA_Buffer:$imag_in,
+    VMLA_Shape:$imag_in_shape,
+    VMLA_Buffer:$real_out,
+    VMLA_Buffer:$imag_out,
+    VMLA_AnyTypeAttr:$real_element_type,
+    VMLA_AnyTypeAttr:$imag_element_type
+  );
+
+  let assemblyFormat = [{
+    $real_in`(`$real_in_shape `:` type($real_in_shape)`)` `,`
+    $imag_in`(`$imag_in_shape `:` type($imag_in_shape)`)` `,`
+    `out` $real_out `,` $imag_out attr-dict `:` $real_element_type `,` $imag_element_type
+  }];
+}
 
 #endif  // IREE_DIALECT_VMLA_OPS