Enable the matmul-to-mmt4d transformation for all element types (not just f32) (#7477)

Background: I earlier attempted to make mixed types work in vector.contract lowerings (see https://reviews.llvm.org/D112508). The closing comment there explains the approach and why we abandoned it in favor of promoting the inputs to the destination element type before forming the vector.contract.

Another minor cleanup is folded into this PR: we drop the flag --iree-codegen-vectorize-linalg-mmt4d from the custom iree-opt flags in the build rules for the e2e matmul tests, because that pass is already enabled by default in iree-translate.
diff --git a/iree/compiler/Codegen/Common/VectorizeMMT4d.cpp b/iree/compiler/Codegen/Common/VectorizeMMT4d.cpp
index 34340ef..c7db560 100644
--- a/iree/compiler/Codegen/Common/VectorizeMMT4d.cpp
+++ b/iree/compiler/Codegen/Common/VectorizeMMT4d.cpp
@@ -14,6 +14,23 @@
 
 namespace {
 
+Value promoteVector(Location loc, Value inputVector, Type promotedElementType,
+                    PatternRewriter &rewriter) {
+  VectorType inputVectorType = inputVector.getType().cast<VectorType>();
+  if (inputVectorType.getElementType() == promotedElementType) {
+    return inputVector;
+  } else {
+    auto promotedVectorType = inputVectorType.clone(promotedElementType);
+    if (promotedElementType.isIntOrIndex()) {
+      return rewriter.create<arith::ExtSIOp>(loc, inputVector,
+                                             promotedVectorType);
+    } else {
+      return rewriter.create<arith::ExtFOp>(loc, inputVector,
+                                            promotedVectorType);
+    }
+  }
+}
+
 /// Converts linalg.mmt4d into vector.contract.
 /// This converts linalg.mmt4d with operands <1x1xM0xK0>, <1x1xK0xN0>
 /// to vector.contract where K0 is the contraction dimension.
@@ -22,12 +39,13 @@
 
   LogicalResult matchAndRewrite(linalg::Mmt4DOp mmt4DOp,
                                 PatternRewriter &rewriter) const override {
-    auto lhs = mmt4DOp.inputs()[0];
-    auto rhs = mmt4DOp.inputs()[1];
-    auto dst = mmt4DOp.outputs()[0];
+    Value lhs = mmt4DOp.inputs()[0];
+    Value rhs = mmt4DOp.inputs()[1];
+    Value dst = mmt4DOp.outputs()[0];
 
-    auto lhsType = lhs.getType().dyn_cast<ShapedType>();
-    auto rhsType = rhs.getType().dyn_cast<ShapedType>();
+    ShapedType lhsType = lhs.getType().dyn_cast<ShapedType>();
+    ShapedType rhsType = rhs.getType().dyn_cast<ShapedType>();
+    ShapedType dstType = dst.getType().dyn_cast<ShapedType>();
 
     // This pattern expects tensors of static shapes.
     // In practice, dynamic shapes are meant to be handled by other passes,
@@ -55,16 +73,20 @@
     int N0 = rhsType.getShape()[2];
     int K0 = lhsType.getShape()[3];
 
-    auto loc = mmt4DOp.getLoc();
-    auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Location loc = mmt4DOp.getLoc();
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
 
-    auto lhsVecType = VectorType::get({1, 1, M0, K0}, rewriter.getF32Type());
-    auto rhsVecType = VectorType::get({1, 1, N0, K0}, rewriter.getF32Type());
-    auto dstVecType = VectorType::get({1, 1, M0, N0}, rewriter.getF32Type());
+    Type lhsElementType = lhsType.getElementType();
+    Type rhsElementType = rhsType.getElementType();
+    Type dstElementType = dstType.getElementType();
 
-    auto lhsVecType2D = VectorType::get({M0, K0}, rewriter.getF32Type());
-    auto rhsVecType2D = VectorType::get({N0, K0}, rewriter.getF32Type());
-    auto dstVecType2D = VectorType::get({M0, N0}, rewriter.getF32Type());
+    auto lhsVecType = VectorType::get({1, 1, M0, K0}, lhsElementType);
+    auto rhsVecType = VectorType::get({1, 1, N0, K0}, rhsElementType);
+    auto dstVecType = VectorType::get({1, 1, M0, N0}, dstElementType);
+
+    auto lhsVecType2D = VectorType::get({M0, K0}, lhsElementType);
+    auto rhsVecType2D = VectorType::get({N0, K0}, rhsElementType);
+    auto dstVecType2D = VectorType::get({M0, N0}, dstElementType);
 
     auto identityMap = rewriter.getMultiDimIdentityMap(4);
 
@@ -84,6 +106,14 @@
     Value dstVec2D =
         rewriter.create<vector::ShapeCastOp>(loc, dstVecType2D, dstVec);
 
+    // Promote, if needed, the element type in the lhs and rhs vectors to
+    // match the dst vector, so that the vector.contract below will involve
+    // only one element type. This is in line with planned design, see
+    // the closing comment on https://reviews.llvm.org/D112508 where the
+    // alternative of using mixed types was considered.
+    Value promLhsVec2d = promoteVector(loc, lhsVec2D, dstElementType, rewriter);
+    Value promRhsVec2d = promoteVector(loc, rhsVec2D, dstElementType, rewriter);
+
     // Generate the vector.contract on 2D vectors replacing the mmt4d op.
     auto m = rewriter.getAffineDimExpr(0);
     auto n = rewriter.getAffineDimExpr(1);
@@ -96,7 +126,7 @@
         {getParallelIteratorTypeName(), getParallelIteratorTypeName(),
          getReductionIteratorTypeName()});
     Value contractResult = rewriter.create<vector::ContractionOp>(
-        loc, lhsVec2D, rhsVec2D, dstVec2D, indexingMaps, iterators);
+        loc, promLhsVec2d, promRhsVec2d, dstVec2D, indexingMaps, iterators);
 
     // Convert the output vector from 2D shape (M0xN0) to 4D shape (1x1xM0xN0)
     Value contractResult4D =
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgMatmulToMmt4D.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgMatmulToMmt4D.cpp
index 1a103f8..7e6cecf 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgMatmulToMmt4D.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgMatmulToMmt4D.cpp
@@ -112,11 +112,6 @@
       return failure();
     }
 
-    // This is for float only matmul for now. Integer data type might require
-    // r.h.s layout change.
-    if (!lhsType.getElementType().isF32() || !rhsType.getElementType().isF32())
-      return failure();
-
     int m = lhsType.getShape()[0];
     int k = rhsType.getShape()[0];
     int n = rhsType.getShape()[1];
diff --git a/iree/test/e2e/regression/BUILD b/iree/test/e2e/regression/BUILD
index 77e48c3..71f6df5 100644
--- a/iree/test/e2e/regression/BUILD
+++ b/iree/test/e2e/regression/BUILD
@@ -131,7 +131,6 @@
     ],
     opt_flags = [
         "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=%d N0=8" % (4 if lhs_rhs_type == "i8" else 1),
-        "--iree-codegen-vectorize-linalg-mmt4d",
     ],
     target_backends_and_drivers = [
         ("dylib-llvm-aot", "dylib"),
@@ -152,7 +151,6 @@
     ],
     opt_flags = [
         "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=%d N0=8" % (4 if lhs_rhs_type == "i8" else 1),
-        "--iree-codegen-vectorize-linalg-mmt4d",
     ],
     target_backends_and_drivers = [
         ("dylib-llvm-aot", "dylib"),
diff --git a/iree/test/e2e/regression/CMakeLists.txt b/iree/test/e2e/regression/CMakeLists.txt
index 795a874..309f907 100644
--- a/iree/test/e2e/regression/CMakeLists.txt
+++ b/iree/test/e2e/regression/CMakeLists.txt
@@ -176,7 +176,6 @@
     "vmvx"
   OPT_FLAGS
     "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=4 N0=8"
-    "--iree-codegen-vectorize-linalg-mmt4d"
 )
 
 iree_generated_trace_runner_test(
@@ -197,7 +196,6 @@
     "vmvx"
   OPT_FLAGS
     "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=1 N0=8"
-    "--iree-codegen-vectorize-linalg-mmt4d"
 )
 
 iree_generated_trace_runner_test(
@@ -216,7 +214,6 @@
     "dylib"
   OPT_FLAGS
     "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=4 N0=8"
-    "--iree-codegen-vectorize-linalg-mmt4d"
 )
 
 iree_generated_trace_runner_test(
@@ -235,7 +232,6 @@
     "dylib"
   OPT_FLAGS
     "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=1 N0=8"
-    "--iree-codegen-vectorize-linalg-mmt4d"
 )
 
 ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###