Add support for padding by complex values (#14367)

When performing a pad on a linalg.* operation we need to generate the
approrpiate zero attr. Added support for these types though the LLVM CPU
stack and bufferization work.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTensorPad.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTensorPad.cpp
index aa1f6fe..219dfc2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTensorPad.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTensorPad.cpp
@@ -84,6 +84,12 @@
     OpBuilder builder(context);
     for (auto &operand : linalgOp->getOpOperands()) {
       auto elemType = getElementTypeOrSelf(operand.get().getType());
+      if (auto complexTy = elemType.dyn_cast<ComplexType>()) {
+        auto zeroAttr = builder.getZeroAttr(complexTy.getElementType());
+        paddingValueAttributes.push_back(
+            ArrayAttr::get(context, {zeroAttr, zeroAttr}));
+        continue;
+      }
       paddingValueAttributes.push_back(builder.getZeroAttr(elemType));
     }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tensor_pad.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tensor_pad.mlir
index 05c0574..956b49a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tensor_pad.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tensor_pad.mlir
@@ -65,3 +65,71 @@
 // CHECK-SAME:      outs(%[[FILL]]
 // CHECK:         %{{.+}} = linalg.generic
 // CHECK-SAME:      outs(%[[MATMUL]]
+
+// -----
+
+func.func @complex_pad_for_fusion() {
+  %c8 = arith.constant 8 : index
+  %c32 = arith.constant 32 : index
+  %cst = complex.constant [0.000000e+00 : f32, 0.0000000e+00 : f32] : complex<f32>
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.constant.load[0] : i32
+  %1 = hal.interface.constant.load[1] : i32
+  %2 = hal.interface.constant.load[2] : i32
+  %3 = hal.interface.constant.load[3] : i32
+  %4 = arith.index_castui %0 : i32 to index
+  %5 = arith.index_castui %1 : i32 to index
+  %6 = arith.index_castui %2 : i32 to index
+  %7 = arith.index_castui %3 : i32 to index
+  %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?xcomplex<f32>>>{%6, %4}
+  %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?xcomplex<f32>>>{%5, %7}
+  %10 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<?x?xcomplex<f32>>>{%6, %7}
+  %workgroup_id_x = hal.interface.workgroup.id[0] : index
+  %workgroup_count_x = hal.interface.workgroup.count[0] : index
+  %workgroup_id_y = hal.interface.workgroup.id[1] : index
+  %workgroup_count_y = hal.interface.workgroup.count[1] : index
+  %11 = affine.apply affine_map<()[s0] -> (s0 * 192)>()[%workgroup_id_y]
+  %12 = affine.apply affine_map<()[s0] -> (s0 * 192)>()[%workgroup_count_y]
+  scf.for %arg0 = %11 to %6 step %12 {
+    %13 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 192)>(%arg0)[%6]
+    %14 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x]
+    %15 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_count_x]
+    scf.for %arg1 = %14 to %7 step %15 {
+      %16 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 128)>(%arg1)[%7]
+      %17 = flow.dispatch.tensor.load %10, offsets = [%arg0, %arg1], sizes = [%13, %16], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<?x?xcomplex<f32>>>{%6, %7} -> tensor<?x?xcomplex<f32>>
+      %18 = flow.dispatch.tensor.load %8, offsets = [%arg0, 0], sizes = [%13, %4], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?xcomplex<f32>>>{%6, %4} -> tensor<?x?xcomplex<f32>>
+      %19 = flow.dispatch.tensor.load %9, offsets = [0, %arg1], sizes = [%4, %16], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?xcomplex<f32>>>{%5, %7} -> tensor<?x?xcomplex<f32>>
+      %20 = scf.for %arg2 = %c0 to %13 step %c8 iter_args(%arg3 = %17) -> (tensor<?x?xcomplex<f32>>) {
+        %21 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 8)>(%arg2)[%13]
+        %22 = scf.for %arg4 = %c0 to %16 step %c32 iter_args(%arg5 = %arg3) -> (tensor<?x?xcomplex<f32>>) {
+          %23 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 32)>(%arg4)[%16]
+          %extracted_slice = tensor.extract_slice %18[%arg2, 0] [%21, %4] [1, 1] : tensor<?x?xcomplex<f32>> to tensor<?x?xcomplex<f32>>
+          %extracted_slice_0 = tensor.extract_slice %19[0, %arg4] [%4, %23] [1, 1] : tensor<?x?xcomplex<f32>> to tensor<?x?xcomplex<f32>>
+          %extracted_slice_1 = tensor.extract_slice %arg5[%arg2, %arg4] [%21, %23] [1, 1] : tensor<?x?xcomplex<f32>> to tensor<?x?xcomplex<f32>>
+          %24 = linalg.fill ins(%cst : complex<f32>) outs(%extracted_slice_1 : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+          %25 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[192, 128, 0], [8, 32, 0], [0, 0, 16]]>} ins(%extracted_slice, %extracted_slice_0 : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>) outs(%24 : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+          %26 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%25 : tensor<?x?xcomplex<f32>>) {
+          ^bb0(%out: complex<f32>):
+            %27 = complex.exp %out : complex<f32>
+            linalg.yield %27 : complex<f32>
+          } -> tensor<?x?xcomplex<f32>>
+          %inserted_slice = tensor.insert_slice %26 into %arg5[%arg2, %arg4] [%21, %23] [1, 1] : tensor<?x?xcomplex<f32>> into tensor<?x?xcomplex<f32>>
+          scf.yield %inserted_slice : tensor<?x?xcomplex<f32>>
+        }
+        scf.yield %22 : tensor<?x?xcomplex<f32>>
+      }
+      flow.dispatch.tensor.store %20, %10, offsets = [%arg0, %arg1], sizes = [%13, %16], strides = [1, 1] : tensor<?x?xcomplex<f32>> -> !flow.dispatch.tensor<writeonly:tensor<?x?xcomplex<f32>>>{%6, %7}
+    }
+  }
+  return
+}
+// CHECK-LABEL: func.func @complex_pad_for_fusion
+// CHECK:         %[[PAD0:.+]] = tensor.pad
+// CHECK:         %[[FILL:.+]] = linalg.fill {{.+}} outs(%[[PAD0]] : tensor<8x32xcomplex<f32>>
+// CHECK:         %[[PAD1:.+]] = tensor.pad
+// CHECK:         %[[PAD2:.+]] = tensor.pad
+// CHECK:         %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME:      ins(%[[PAD1]], %[[PAD2]] : tensor<8x?xcomplex<f32>>, tensor<?x32xcomplex<f32>>
+// CHECK-SAME:      outs(%[[FILL]]
+// CHECK:         %{{.+}} = linalg.generic
+// CHECK-SAME:      outs(%[[MATMUL]]
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferOps.cpp
index 13d3eb7..4df60f0 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferOps.cpp
@@ -30,6 +30,7 @@
 
     auto originalType = op.getResult().getType();
     auto targetType = typeConverter->convertType(originalType);
+    auto targetBitwidth = IREE::Util::getTypeBitWidth(targetType);
     int32_t validByteWidth =
         IREE::Util::getRoundedElementByteWidth(originalType);
 
@@ -65,8 +66,7 @@
       auto hi = rewriter.create<arith::ShLIOp>(
           op.getLoc(),
           rewriter.create<arith::ExtUIOp>(
-              op.getLoc(),
-              rewriter.getIntegerType(targetType.getIntOrFloatBitWidth()),
+              op.getLoc(), rewriter.getIntegerType(targetBitwidth),
               hiCallOp.getResult(0)),
           rewriter.create<arith::ConstantIntOp>(op.getLoc(), 32, 32));
 
@@ -75,8 +75,7 @@
           ArrayRef<Value>{adaptor.getSourceBuffer(), sourceOffset,
                           halfByteWidth});
       auto lo = rewriter.create<arith::ExtUIOp>(
-          op.getLoc(),
-          rewriter.getIntegerType(targetType.getIntOrFloatBitWidth()),
+          op.getLoc(), rewriter.getIntegerType(targetBitwidth),
           loCallOp.getResult(0));
 
       value = rewriter.create<arith::OrIOp>(op.getLoc(), lo, hi);