Add folding arithmetic extensions (#15953)

Add folding arithmetic extensions on floating point data types into
vector contraction operations
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index a46aad8..f6d59a5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -43,6 +43,7 @@
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -231,6 +232,11 @@
       patterns, [](Operation *op) { return true; });
 }
 
+void transform_dialect::ApplyFoldArithExtIntoContractionOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateFoldArithExtensionPatterns(patterns);
+}
+
 void transform_dialect::ApplyFoldReshapeIntoTensorHalInterfacePatternsOp::
     populatePatterns(RewritePatternSet &patterns) {
   populateReshapeToInterfaceTensorPatterns(patterns);
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index 69cec35..041211a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -57,6 +57,20 @@
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyFoldArithExtIntoContractionOp : Op<Transform_Dialect,
+    "apply_patterns.iree.fold_arith_ext_into_contraction",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Populate pattern to fold arithmetic extensions on floating point data types into
+    vector contraction operations. linalg.matmul introduces arithmetic
+    extensions on its operands.
+  }];
+
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyFoldFillIntoPadPatternsOp : Op<Transform_Dialect,
     "apply_patterns.iree.fold_fill_into_pad",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index 2e6bcd9..d55b1c9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -62,6 +62,7 @@
             "test_partitionable_loops_interface.mlir",
             "tile_and_distribute_to_workgroups.mlir",
             "transform_buffer_opt.mlir",
+            "transform_fold_arith_extf_into_vector_contract.mlir",
             "transform_match_partial_reduction.mlir",
             "transform_ops_invalid.mlir",
             "transpose_canonicalization.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index 2cbd101..8cc7325 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -58,6 +58,7 @@
     "test_partitionable_loops_interface.mlir"
     "tile_and_distribute_to_workgroups.mlir"
     "transform_buffer_opt.mlir"
+    "transform_fold_arith_extf_into_vector_contract.mlir"
     "transform_match_partial_reduction.mlir"
     "transform_ops_invalid.mlir"
     "transpose_canonicalization.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/transform_fold_arith_extf_into_vector_contract.mlir b/compiler/src/iree/compiler/Codegen/Common/test/transform_fold_arith_extf_into_vector_contract.mlir
new file mode 100644
index 0000000..0bd9e48
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/transform_fold_arith_extf_into_vector_contract.mlir
@@ -0,0 +1,28 @@
+// RUN: iree-opt -iree-transform-dialect-interpreter %s | FileCheck %s
+
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @fold_arith_extf_into_contract
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<64x64xf16>, %[[ARG1:.*]]: vector<64x64xf16>, %[[ARG2:.*]]: vector<64x64xf32>)
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32>
+//  CHECK-NEXT:   return %[[R]] : vector<64x64xf32>
+func.func @fold_arith_extf_into_contract(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> {
+    %lhs_f32 = arith.extf %arg0 : vector<64x64xf16> to vector<64x64xf32>
+    %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32>
+    %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32>
+    return %result : vector<64x64xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.iree.fold_arith_ext_into_contraction
+    } : !transform.any_op
+    transform.yield
+  } // @__transform_main
+} // module