Remove Upcasting schedule from TileAndFuse (#19669)

We can't support the accumulator-upcasting schedule end-to-end (e2e), and
we also don't want to support it, as it plays loose with numerics. Instead
of retrying with `canUpcastAcc=true`, we now bail out.
Fixes: https://github.com/iree-org/iree/issues/19532

Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 3c514c5..61c325c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -165,13 +165,6 @@
       problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
       transposedLhs, transposedRhs, /*canUpcastAcc=*/false,
       /*mustBeAligned*/ mustBeAligned, doCPromotion);
-  if (!schedule) {
-    // Then try again by allowing upcasting accumulator.
-    schedule = deduceMMASchedule(
-        problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
-        transposedLhs, transposedRhs, /*canUpcastAcc=*/true,
-        /*mustBeAligned*/ mustBeAligned, doCPromotion);
-  }
   return schedule;
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
index 47ccb62..3d137e5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
@@ -10,21 +10,23 @@
 #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>
 #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
 #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
-func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor<10x64x2048xf16>) -> tensor<2x10x64x64xf16> {
+func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor<10x64x2048xf16>) -> tensor<2x10x64x64xf32> {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.000000e+00 : f16
-  %5 = tensor.empty() : tensor<2x10x64x64xf16>
-  %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x10x64x64xf16>) -> tensor<2x10x64x64xf16>
+  %cst = arith.constant 0.000000e+00 : f32
+  %5 = tensor.empty() : tensor<2x10x64x64xf32>
+  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x10x64x64xf32>) -> tensor<2x10x64x64xf32>
   %7 = linalg.generic {
     indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
-    ins(%lhs, %rhs : tensor<2x64x2048xf16>, tensor<10x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf16>) {
-  ^bb0(%in: f16, %in_0: f16, %out: f16):
-    %8 = arith.mulf %in, %in_0 : f16
-    %9 = arith.addf %8, %out : f16
-    linalg.yield %9 : f16
-  } -> tensor<2x10x64x64xf16>
-  return %7 : tensor<2x10x64x64xf16>
+    ins(%lhs, %rhs : tensor<2x64x2048xf16>, tensor<10x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf32>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f32):
+    %8 = arith.extf %in : f16 to f32
+    %9 = arith.extf %in_0 : f16 to f32
+    %10 = arith.mulf %8, %9 : f32
+    %11 = arith.addf %10, %out : f32
+    linalg.yield %11 : f32
+  } -> tensor<2x10x64x64xf32>
+  return %7 : tensor<2x10x64x64xf32>
 }
 
 // CHECK-LABEL: func.func @expanded_matmul_transpose_b
@@ -46,21 +48,23 @@
 #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d4, d5)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4x32x128x16xf16>) -> tensor<10x4x32x32xf16> {
+func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4x32x128x16xf16>) -> tensor<10x4x32x32xf32> {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.000000e+00 : f16
-  %5 = tensor.empty() : tensor<10x4x32x32xf16>
-  %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<10x4x32x32xf16>) -> tensor<10x4x32x32xf16>
+  %cst = arith.constant 0.000000e+00 : f32
+  %5 = tensor.empty() : tensor<10x4x32x32xf32>
+  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<10x4x32x32xf32>) -> tensor<10x4x32x32xf32>
   %7 = linalg.generic {
     indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
-    ins(%lhs, %rhs : tensor<10x32x128x16xf16>, tensor<4x32x128x16xf16>) outs(%6 : tensor<10x4x32x32xf16>) {
-  ^bb0(%in: f16, %in_0: f16, %out: f16):
-    %8 = arith.mulf %in, %in_0 : f16
-    %9 = arith.addf %8, %out : f16
-    linalg.yield %9 : f16
-  } -> tensor<10x4x32x32xf16>
-  return %7 : tensor<10x4x32x32xf16>
+    ins(%lhs, %rhs : tensor<10x32x128x16xf16>, tensor<4x32x128x16xf16>) outs(%6 : tensor<10x4x32x32xf32>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f32):
+    %8 = arith.extf %in : f16 to f32
+    %9 = arith.extf %in_0 : f16 to f32
+    %10 = arith.mulf %8, %9 : f32
+    %11 = arith.addf %10, %out : f32
+    linalg.yield %11 : f32
+  } -> tensor<10x4x32x32xf32>
+  return %7 : tensor<10x4x32x32xf32>
 }
 
 // CHECK-LABEL: func.func @multi_dim_mma_schedule
@@ -79,23 +83,25 @@
 #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d5, d6)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d5, d6)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
-func.func @dynamic_multi_dim_mma_schedule(%lhs: tensor<?x6x16x?x16xf16>, %rhs: tensor<?x32x?x16xf16>) -> tensor<?x6x?x16x32xf16> {
+func.func @dynamic_multi_dim_mma_schedule(%lhs: tensor<?x6x16x?x16xf16>, %rhs: tensor<?x32x?x16xf16>) -> tensor<?x6x?x16x32xf32> {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.000000e+00 : f16
+  %cst = arith.constant 0.000000e+00 : f32
   %d0 = tensor.dim %lhs, %c0 : tensor<?x6x16x?x16xf16>
   %d2 = tensor.dim %rhs, %c0 : tensor<?x32x?x16xf16>
-  %5 = tensor.empty(%d0, %d2) : tensor<?x6x?x16x32xf16>
-  %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<?x6x?x16x32xf16>) -> tensor<?x6x?x16x32xf16>
+  %5 = tensor.empty(%d0, %d2) : tensor<?x6x?x16x32xf32>
+  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<?x6x?x16x32xf32>) -> tensor<?x6x?x16x32xf32>
   %7 = linalg.generic {
     indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
-    ins(%lhs, %rhs : tensor<?x6x16x?x16xf16>, tensor<?x32x?x16xf16>) outs(%6 : tensor<?x6x?x16x32xf16>) {
-  ^bb0(%in: f16, %in_0: f16, %out: f16):
-    %8 = arith.mulf %in, %in_0 : f16
-    %9 = arith.addf %8, %out : f16
-    linalg.yield %9 : f16
-  } -> tensor<?x6x?x16x32xf16>
-  return %7 : tensor<?x6x?x16x32xf16>
+    ins(%lhs, %rhs : tensor<?x6x16x?x16xf16>, tensor<?x32x?x16xf16>) outs(%6 : tensor<?x6x?x16x32xf32>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f32):
+    %8 = arith.extf %in : f16 to f32
+    %9 = arith.extf %in_0 : f16 to f32
+    %10 = arith.mulf %8, %9 : f32
+    %11 = arith.addf %10, %out : f32
+    linalg.yield %11 : f32
+  } -> tensor<?x6x?x16x32xf32>
+  return %7 : tensor<?x6x?x16x32xf32>
 }
 
 // CHECK-LABEL: func.func @dynamic_multi_dim_mma_schedule