[LLVMGPU] Generalize VectorContractOpInfo based on indexing maps (#17625)

This patch generalizes VectorContractOpInfo to work on any kind of
vector.contract. The "kind" field was not being used by any pass anyway;
the passes were relying directly on the m, n, and k dimensions instead.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
index 910c5ad..a52ead0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
@@ -67,9 +67,6 @@
 
     // Infer the contract kind so that we know know to correlate M/N/K dims.
     VectorContractOpInfo opDetail(contractOp);
-    if (opDetail.getOpKind() == VectorContractOpInfo::OpKind::UNKNOWN) {
-      return rewriter.notifyMatchFailure(contractOp, "unknown contract kind");
-    }
 
     SmallVector<int64_t> distShape = resultLayout.getDistributedShape();
     LLVM_DEBUG({
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index d4fd9fc..5df8a36 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -833,9 +833,8 @@
     llvm::errs() << "Getting mma layouts for:\n" << contractOp << "\n";
     llvm::errs() << "For schedule: " << *this << "\n";
   });
-  if (opInfo.getOpKind() == VectorContractOpInfo::OpKind::UNKNOWN) {
-    LLVM_DEBUG({ llvm::errs() << "Unknown contraction kind\n"; });
-    return failure();
+  if (opInfo.getKDims().size() != 1) {
+    return contractOp->emitError("Unimplemented: > 1 k dims");
   }
 
   auto mmaAttr = llvm::cast<MMAAttr>(getIntrinsic());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
index 3aead0a..84725a0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
@@ -32,9 +32,6 @@
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
     VectorContractOpInfo opInfo(contractOp);
-    if (opInfo.getOpKind() == VectorContractOpInfo::OpKind::UNKNOWN) {
-      return rewriter.notifyMatchFailure(contractOp, "unhandled contract kind");
-    }
 
     auto srcCType = dyn_cast<VectorType>(contractOp.getAccType());
     if (!srcCType) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
index 6dfaf38..a3d13d0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
@@ -54,6 +54,34 @@
 // -----
 
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+                                              workgroup_size = [64, 1, 1]
+                                              subgroup_size = 64,
+      {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
+
+func.func @mfma_matmul_96x64x16_mmtt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>, %init: vector<64x96xf32>) -> vector<64x96xf32> attributes { translation_info = #translation } {
+    %0 = vector.contract {
+      indexing_maps = [affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (n, k)>, affine_map<(m, n, k) -> (n, m)>],
+      iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+      %lhs, %rhs, %init : vector<96x16xf16>, vector<64x16xf16> into vector<64x96xf32>
+  return %0 : vector<64x96xf32>
+}
+
+//      CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
+// CHECK-SAME: thread_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [2, 32]>
+//      CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
+// CHECK-SAME: thread_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [2, 32]>
+//      CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 3], outers_per_batch = [1, 4], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
+// CHECK-SAME: subgroup_order = [1, 0], thread_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 32]>
+
+// -----
+
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
                                               workgroup_size = [64, 2, 1]
                                               subgroup_size = 64,
       {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 2, subgroup_n_count = 1>}>
diff --git a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp
index 78007c6..0bf721c 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp
@@ -26,33 +26,11 @@
   return std::make_pair(outMDims.back(), outNDims.back());
 }
 
-VectorContractOpInfo::OpKind
-VectorContractOpInfo::inferOpKind(MLIRContext *ctx,
-                                  SmallVector<AffineMap> maps) {
-  if (contractionDims.k.size() != 1) {
-    return OpKind::UNKNOWN;
-  }
-  if (!contractionDims.batch.empty()) {
-    if (contractionDims.batch.size() > 1 || contractionDims.batch[0] != 0) {
-      return OpKind::UNKNOWN;
-    }
-    if (*maps[0].getResultPosition(getAffineDimExpr(0, ctx)) != 0 ||
-        *maps[1].getResultPosition(getAffineDimExpr(0, ctx)) != 0 ||
-        *maps[2].getResultPosition(getAffineDimExpr(0, ctx)) != 0) {
-      return OpKind::UNKNOWN;
-    }
-  }
+VectorContractOpInfo::VectorContractOpInfo(vector::ContractionOp op) {
+  contractionDims = *linalg::inferContractionDims(op.getIndexingMapsArray());
 
-  int64_t innerM = contractionDims.m.back();
-  int64_t innerN = contractionDims.n.back();
-  int64_t k = contractionDims.k.back();
-
-  int64_t lhsM = *maps[0].getResultPosition(getAffineDimExpr(innerM, ctx));
-  lhsKDim = *maps[0].getResultPosition(getAffineDimExpr(k, ctx));
-  int64_t rhsN = *maps[1].getResultPosition(getAffineDimExpr(innerN, ctx));
-  rhsKDim = *maps[1].getResultPosition(getAffineDimExpr(k, ctx));
-  int64_t outM = *maps[2].getResultPosition(getAffineDimExpr(innerM, ctx));
-  int64_t outN = *maps[2].getResultPosition(getAffineDimExpr(innerN, ctx));
+  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
+  MLIRContext *ctx = op.getContext();
 
   for (auto m : contractionDims.m) {
     lhsMDims.push_back(*maps[0].getResultPosition(getAffineDimExpr(m, ctx)));
@@ -63,15 +41,9 @@
     outNDims.push_back(*maps[2].getResultPosition(getAffineDimExpr(n, ctx)));
   }
 
-  if (outM < outN) {
-    if (lhsM < lhsKDim) {
-      if (rhsN < rhsKDim) {
-        return OpKind::MK_NK_MN;
-      }
-      return OpKind::MK_KN_MN;
-    }
-  }
-  return OpKind::UNKNOWN;
+  int64_t k = contractionDims.k.back();
+  lhsKDim = *maps[0].getResultPosition(getAffineDimExpr(k, ctx));
+  rhsKDim = *maps[1].getResultPosition(getAffineDimExpr(k, ctx));
 }
 
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h
index 64cfc0a..f8e0cf9 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h
@@ -6,21 +6,13 @@
 
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/BuiltinTypes.h"
 
 namespace mlir::iree_compiler {
 
 /// A class for querying information about a contract op.
 class VectorContractOpInfo {
 public:
-  enum class OpKind { MK_KN_MN, MK_NK_MN, UNKNOWN };
-
-  explicit VectorContractOpInfo(vector::ContractionOp op) {
-    contractionDims = *linalg::inferContractionDims(op.getIndexingMapsArray());
-    opKind = inferOpKind(op.getContext(), op.getIndexingMapsArray());
-  }
-
-  OpKind getOpKind() const { return opKind; }
+  explicit VectorContractOpInfo(vector::ContractionOp op);
 
   // Returns the (LHS M, RHS N) dimension index pair.
   std::pair<int, int> getOperandMNIndex() const;
@@ -58,11 +50,6 @@
   SmallVector<int64_t> outNDims;
 
 private:
-  // Gets the kind of a contract op with the given indexing |maps|.
-  OpKind inferOpKind(MLIRContext *ctx, SmallVector<AffineMap> maps);
-
-  OpKind opKind = OpKind::UNKNOWN;
-
   linalg::ContractionDimensions contractionDims;
 };