Merge pull request #3048 from GMNGeoffrey:main-to-google

PiperOrigin-RevId: 329558845
diff --git a/colab/edge_detection.ipynb b/colab/edge_detection.ipynb
index eaa26b6..18528e4 100644
--- a/colab/edge_detection.ipynb
+++ b/colab/edge_detection.ipynb
@@ -338,11 +338,10 @@
         "\n",
         "Overview:\n",
         "\n",
-        "1.  Save the `tf.Module` as a `SavedModel`\n",
-        "2.  Use IREE's python bindings to load the `SavedModel` into MLIR in the `mhlo` dialect\n",
-        "3.  Save the MLIR to a file (can stop here to use it from another application)\n",
-        "4.  Compile the `mhlo` MLIR into a VM module for IREE to execute\n",
-        "5.  Run the VM module through IREE's runtime to test the edge detection function"
+        "1.  Convert the `tf.Module` into an IREE compiler module\n",
+        "2.  Save the MLIR assembly from the module into a file (can stop here to use it from another application)\n",
+        "3.  Compile the `mhlo` MLIR into a VM module for IREE to execute\n",
+        "4.  Run the VM module through IREE's runtime to test the edge detection function"
       ]
     },
     {
diff --git a/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp b/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
index 1b5291c..9c8a1f6 100644
--- a/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
+++ b/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
@@ -135,8 +135,8 @@
   Value vSubgroupY = b.create<ConstantIndexOp>(loc, numSubgroupY);
   SmallVector<linalg::ProcInfo, 2> procInfo(2);
   using namespace edsc::op;
-  procInfo[0] = {sg % vSubgroupX, vSubgroupX};
-  procInfo[1] = {sgdiv % vSubgroupY, vSubgroupY};
+  procInfo[0] = {sgdiv % vSubgroupY, vSubgroupY};
+  procInfo[1] = {sg % vSubgroupX, vSubgroupX};
   return procInfo;
 }
 
@@ -325,7 +325,7 @@
               linalg::LinalgTilingOptions()
                   .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
                   .setTileSizes(
-                      {tileN / numSubgroupY, tileM / numSubgroupX, tileK})
+                      {tileM / numSubgroupY, tileN / numSubgroupX, tileK})
                   .setDistributionOptions(SGDistribute));
     }
     strategy.vectorize<linalg::MatmulOp>().unrollVector<vector::ContractionOp>(
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.cpp b/iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.cpp
index b5e9da8..a656bcd 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.cpp
@@ -14,7 +14,9 @@
 #include "iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.h"
 
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/SPIRV/TargetAndABI.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 
 using namespace mlir;
@@ -50,7 +52,8 @@
 }
 
 bool supportsCooperativeMatrix(Operation* op) {
-  if (isa<vector::TransferReadOp>(op) || isa<vector::TransferWriteOp>(op))
+  if (isa<vector::TransferReadOp, vector::TransferWriteOp, scf::ForOp,
+          scf::YieldOp>(op))
     return true;
   if (isa<vector::ContractionOp>(op) &&
       isLegalVectorContract(cast<vector::ContractionOp>(op)))
@@ -75,12 +78,14 @@
     auto contract = dyn_cast<vector::ContractionOp>(op);
     if (contract == nullptr) return;
     auto hasVectorDest = [](Operation* op) {
+      if (isa<ConstantOp, AllocOp>(op)) return false;
       for (auto resultType : op->getResultTypes()) {
         if (resultType.isa<VectorType>()) return true;
       }
+      if (op->getNumResults() == 0) return true;
       return false;
     };
-    auto dependentOps = getSlice(op, hasVectorDest);
+    auto dependentOps = getSlice(op, hasVectorDest, hasVectorDest);
     for (auto* dependeOp : dependentOps) {
       // If any instruction cannot use cooperative matrix drop the whole chaine.
       // In the future we can introduce "bitcast" type of conversion to allow
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
index eb4f098..9737bb6 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
@@ -66,3 +66,41 @@
     return
   }
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader, CooperativeMatrixNV, Int8, Float16, StorageUniform16, StorageBuffer8BitAccess, Float16Buffer], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix, SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @kernel_matmul_licm(%arg0: memref<4096x4096xi8>, %arg1: memref<4096x4096xi8>, %arg2: memref<4096x4096xi32>) attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
+    %c32 = constant 32 : index
+    %c4096 = constant 4096 : index
+    %c0 = constant 0 : index
+    %c0_i32 = constant 0 : i32
+    %c0_i8 = constant 0 : i8
+    // CHECK: %[[C:.+]] = spv.CooperativeMatrixLoadNV %{{.*}}, %{{.*}}, %{{.*}}
+    %4 = vector.transfer_read %arg2[%c0, %c0], %c0_i32 {masked = [false, false]} : memref<4096x4096xi32>, vector<16x16xi32>
+    // CHECK: %[[ACC:.+]] = spv.Variable : !spv.ptr<!spv.coopmatrix<16x16xi32, Subgroup>, Function>
+    // CHECK: spv.loop {
+      // CHECK: spv.Branch ^[[BB:.+]](%{{.*}}, %[[C]] : i32, !spv.coopmatrix<16x16xi32, Subgroup>)
+      // CHECK: ^[[BB]](%{{.*}}: i32, %[[C1:.+]]: !spv.coopmatrix<16x16xi32, Subgroup>)
+    %5 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %4) -> (vector<16x16xi32>) {
+      // CHECK: %[[A:.+]] = spv.CooperativeMatrixLoadNV %{{.*}}, %{{.*}}, %{{.*}}
+      %6 = vector.transfer_read %arg0[%c0, %arg3], %c0_i8 {masked = [false, false]} : memref<4096x4096xi8>, vector<16x32xi8>
+      // CHECK: %[[B:.+]] = spv.CooperativeMatrixLoadNV %{{.*}}, %{{.*}}, %{{.*}}
+      %7 = vector.transfer_read %arg1[%arg3, %c0], %c0_i8 {masked = [false, false]} : memref<4096x4096xi8>, vector<32x16xi8>
+      // CHECK: %[[R:.+]] = spv.CooperativeMatrixMulAddNV %[[A]], %[[B]], %[[C1]]
+      %8 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %6, %7, %arg4 : vector<16x32xi8>, vector<32x16xi8> into vector<16x16xi32>
+      // CHECK: spv.Store "Function" %[[ACC]], %[[R]] : !spv.coopmatrix<16x16xi32, Subgroup>
+      // CHECK: spv.Branch ^[[BB]](%{{.*}}, %[[R]] : i32, !spv.coopmatrix<16x16xi32, Subgroup>)
+      scf.yield %8 : vector<16x16xi32>
+    }
+    // CHECK: %[[ACCv:.+]] = spv.Load "Function" %[[ACC]] : !spv.coopmatrix<16x16xi32, Subgroup>
+    // CHECK: spv.CooperativeMatrixStoreNV %{{.*}}, %[[ACCv]], %{{.*}}, %{{.*}}
+    vector.transfer_write %5, %arg2[%c0, %c0] : vector<16x16xi32>, memref<4096x4096xi32>
+    return
+  }
+}