[ROCM] Add zero fill check to ukernel patterns (#21793)

Add missing zero-fill checks to the PDL ukernel patterns, as the ukernel
implementations assume their output operand is zero-initialized.
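
A minimal sketch of the producer chain the patterns now require on the
output operand (illustrative static shapes; the patterns themselves are
agnostic to the concrete tensor sizes):

```mlir
%cst = arith.constant 0.000000e+00 : f32
%empty = tensor.empty() : tensor<256x1024xf32>
// The matmul-like generic is only rewritten to a ukernel when its init
// operand is produced by a zero fill like this one; a fill with any other
// constant (e.g. 1.0, as in the new negative tests) no longer matches.
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<256x1024xf32>) -> tensor<256x1024xf32>
```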

Also set the correct workgroup tile sizes for the f16 large expanded
ukernel.
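
For reference, only the workgroup tile line of that pattern's
compilation_info changes; the surrounding translation_info is untouched:

```mlir
// Before: workgroup = [1, 128, 256, 0]
lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 256, 256, 0]}>
```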

Signed-off-by: Jorn Tuyls <jorn.tuyls@gmail.com>
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns.mlir b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns.mlir
index 18c5013..6ddcba2 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns.mlir
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns.mlir
@@ -50,6 +50,29 @@
 
 // -----
 
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_matmul_f8_medium_no_zero_fill(%arg0: tensor<1x128x4096xf8E4M3FNUZ>, %arg1: tensor<1024x4096xf8E4M3FNUZ>) -> tensor<1x128x1024xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x128x1024xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x128x1024xf32>) -> tensor<1x128x1024xf32>
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x128x4096xf8E4M3FNUZ>, tensor<1024x4096xf8E4M3FNUZ>) outs(%1 : tensor<1x128x1024xf32>) {
+    ^bb0(%in: f8E4M3FNUZ, %in_4: f8E4M3FNUZ, %out: f32):
+      %12 = arith.extf %in : f8E4M3FNUZ to f32
+      %13 = arith.extf %in_4 : f8E4M3FNUZ to f32
+      %14 = arith.mulf %12, %13 : f32
+      %15 = arith.addf %out, %14 : f32
+      linalg.yield %15 : f32
+    } -> tensor<1x128x1024xf32>
+  return %2 : tensor<1x128x1024xf32>
+}
+// CHECK-LABEL: @negative_matmul_f8_medium_no_zero_fill
+// CHECK-NOT:     compilation_info = #iree_codegen.compilation_info
+// CHECK-NOT:     iree_codegen.ukernel = #iree_codegen.ukernel_descriptor
+
+// -----
+
 // Through a constraint, the inner dimension is known to be a multiple of 128 and has a lower bound of 512, so should be matched.
 
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
@@ -136,6 +159,29 @@
 
 // -----
 
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_matmul_f8_large_no_zero_fill(%arg0: tensor<1x256x4096xf8E4M3FNUZ>, %arg1: tensor<1024x4096xf8E4M3FNUZ>) -> tensor<1x256x1024xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x256x1024xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x256x1024xf32>) -> tensor<1x256x1024xf32>
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x256x4096xf8E4M3FNUZ>, tensor<1024x4096xf8E4M3FNUZ>) outs(%1 : tensor<1x256x1024xf32>) {
+    ^bb0(%in: f8E4M3FNUZ, %in_4: f8E4M3FNUZ, %out: f32):
+      %12 = arith.extf %in : f8E4M3FNUZ to f32
+      %13 = arith.extf %in_4 : f8E4M3FNUZ to f32
+      %14 = arith.mulf %12, %13 : f32
+      %15 = arith.addf %out, %14 : f32
+      linalg.yield %15 : f32
+    } -> tensor<1x256x1024xf32>
+  return %2 : tensor<1x256x1024xf32>
+}
+// CHECK-LABEL: @negative_matmul_f8_large_no_zero_fill
+// CHECK-NOT:     compilation_info = #iree_codegen.compilation_info
+// CHECK-NOT:     iree_codegen.ukernel = #iree_codegen.ukernel_descriptor
+
+// -----
+
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -162,6 +208,75 @@
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_matmul_f16_no_zero_fill(%arg0: tensor<1024x4096xf16>, %arg1: tensor<1024x4096xf16>) -> tensor<1024x1024xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1024x1024xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1024x4096xf16>, tensor<1024x4096xf16>) outs(%1 : tensor<1024x1024xf32>) {
+    ^bb0(%in: f16, %in_4: f16, %out: f32):
+      %12 = arith.extf %in : f16 to f32
+      %13 = arith.extf %in_4 : f16 to f32
+      %14 = arith.mulf %12, %13 : f32
+      %15 = arith.addf %out, %14 : f32
+      linalg.yield %15 : f32
+    } -> tensor<1024x1024xf32>
+  return %2 : tensor<1024x1024xf32>
+}
+// CHECK-LABEL: @negative_matmul_f16_no_zero_fill
+// CHECK-NOT:     compilation_info = #iree_codegen.compilation_info
+// CHECK-NOT:     iree_codegen.ukernel = #iree_codegen.ukernel_descriptor
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_matmul_f16_medium_expanded_no_zero_fill(%arg0: tensor<1x128x4096xf16>, %arg1: tensor<1024x4096xf16>) -> tensor<1x128x1024xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x128x1024xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x128x1024xf32>) -> tensor<1x128x1024xf32>
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x128x4096xf16>, tensor<1024x4096xf16>) outs(%1 : tensor<1x128x1024xf32>) {
+    ^bb0(%in: f16, %in_4: f16, %out: f32):
+      %12 = arith.extf %in : f16 to f32
+      %13 = arith.extf %in_4 : f16 to f32
+      %14 = arith.mulf %12, %13 : f32
+      %15 = arith.addf %out, %14 : f32
+      linalg.yield %15 : f32
+    } -> tensor<1x128x1024xf32>
+  return %2 : tensor<1x128x1024xf32>
+}
+// CHECK-LABEL: @negative_matmul_f16_medium_expanded_no_zero_fill
+// CHECK-NOT:     compilation_info = #iree_codegen.compilation_info
+// CHECK-NOT:     iree_codegen.ukernel = #iree_codegen.ukernel_descriptor
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_matmul_f16_large_expanded_no_zero_fill(%arg0: tensor<1x256x4096xf16>, %arg1: tensor<1024x4096xf16>) -> tensor<1x256x1024xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x256x1024xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x256x1024xf32>) -> tensor<1x256x1024xf32>
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x256x4096xf16>, tensor<1024x4096xf16>) outs(%1 : tensor<1x256x1024xf32>) {
+    ^bb0(%in: f16, %in_4: f16, %out: f32):
+      %12 = arith.extf %in : f16 to f32
+      %13 = arith.extf %in_4 : f16 to f32
+      %14 = arith.mulf %12, %13 : f32
+      %15 = arith.addf %out, %14 : f32
+      linalg.yield %15 : f32
+    } -> tensor<1x256x1024xf32>
+  return %2 : tensor<1x256x1024xf32>
+}
+// CHECK-LABEL: @negative_matmul_f16_large_expanded_no_zero_fill
+// CHECK-NOT:     compilation_info = #iree_codegen.compilation_info
+// CHECK-NOT:     iree_codegen.ukernel = #iree_codegen.ukernel_descriptor
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
 func.func @negative_matmul_bf16(%arg0: tensor<256x4096xbf16>, %arg1: tensor<1024x4096xbf16>) -> tensor<256x1024xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<256x1024xf32>
@@ -182,6 +297,29 @@
 
 // -----
 
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_matmul_bf16_large_no_zero_fill(%arg0: tensor<1024x4096xbf16>, %arg1: tensor<1024x4096xbf16>) -> tensor<1024x1024xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1024x1024xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1024x4096xbf16>, tensor<1024x4096xbf16>) outs(%1 : tensor<1024x1024xf32>) {
+    ^bb0(%in: bf16, %in_4: bf16, %out: f32):
+      %12 = arith.extf %in : bf16 to f32
+      %13 = arith.extf %in_4 : bf16 to f32
+      %14 = arith.mulf %12, %13 : f32
+      %15 = arith.addf %out, %14 : f32
+      linalg.yield %15 : f32
+    } -> tensor<1024x1024xf32>
+  return %2 : tensor<1024x1024xf32>
+}
+// CHECK-LABEL: @negative_matmul_bf16_large_no_zero_fill
+// CHECK-NOT:     compilation_info = #iree_codegen.compilation_info
+// CHECK-NOT:     iree_codegen.ukernel = #iree_codegen.ukernel_descriptor
+
+// -----
+
 // The dynamic dimension is a multiple of 256, but doesn't have a lower bound of 256.
 
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
@@ -207,3 +345,49 @@
 // CHECK-LABEL: @negative_matmul_bf16_dynamic_lower_bound
 // CHECK-NOT:     compilation_info = #iree_codegen.compilation_info
 // CHECK-NOT:     iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_bf16_expanded", tensor>
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_matmul_bf16_expanded_medium_no_zero_fill(%arg0: tensor<1x128x4096xbf16>, %arg1: tensor<1024x4096xbf16>) -> tensor<1x128x1024xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x128x1024xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x128x1024xf32>) -> tensor<1x128x1024xf32>
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x128x4096xbf16>, tensor<1024x4096xbf16>) outs(%1 : tensor<1x128x1024xf32>) {
+    ^bb0(%in: bf16, %in_4: bf16, %out: f32):
+      %12 = arith.extf %in : bf16 to f32
+      %13 = arith.extf %in_4 : bf16 to f32
+      %14 = arith.mulf %12, %13 : f32
+      %15 = arith.addf %out, %14 : f32
+      linalg.yield %15 : f32
+    } -> tensor<1x128x1024xf32>
+  return %2 : tensor<1x128x1024xf32>
+}
+// CHECK-LABEL: @negative_matmul_bf16_expanded_medium_no_zero_fill
+// CHECK-NOT:     compilation_info = #iree_codegen.compilation_info
+// CHECK-NOT:     iree_codegen.ukernel = #iree_codegen.ukernel_descriptor
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_matmul_bf16_expanded_large_no_zero_fill(%arg0: tensor<1x256x4096xbf16>, %arg1: tensor<1024x4096xbf16>) -> tensor<1x256x1024xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x256x1024xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x256x1024xf32>) -> tensor<1x256x1024xf32>
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x256x4096xbf16>, tensor<1024x4096xbf16>) outs(%1 : tensor<1x256x1024xf32>) {
+    ^bb0(%in: bf16, %in_4: bf16, %out: f32):
+      %12 = arith.extf %in : bf16 to f32
+      %13 = arith.extf %in_4 : bf16 to f32
+      %14 = arith.mulf %12, %13 : f32
+      %15 = arith.addf %out, %14 : f32
+      linalg.yield %15 : f32
+    } -> tensor<1x256x1024xf32>
+  return %2 : tensor<1x256x1024xf32>
+}
+// CHECK-LABEL: @negative_matmul_bf16_expanded_large_no_zero_fill
+// CHECK-NOT:     compilation_info = #iree_codegen.compilation_info
+// CHECK-NOT:     iree_codegen.ukernel = #iree_codegen.ukernel_descriptor
diff --git a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/ukernel_patterns_gfx942.mlir b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/ukernel_patterns_gfx942.mlir
index c6c83b4..ce24d04 100644
--- a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/ukernel_patterns_gfx942.mlir
+++ b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/ukernel_patterns_gfx942.mlir
@@ -15,13 +15,20 @@
   %lhs_type = pdl.type
   %rhs_type = pdl.type
   %out_type = pdl.type
+  %zero_type = pdl.type : f32
 
   %lhs = pdl.operand : %lhs_type
   %rhs = pdl.operand : %rhs_type
   %out_init = pdl.operand : %out_type
 
-  // Match the a matmul-like generic with above indexin maps.
-  %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  %zero_val = pdl.attribute = 0. : f32
+  %zero_op = pdl.operation "arith.constant" {"value" = %zero_val} -> (%zero_type : !pdl.type)
+  %zero = pdl.result 0 of %zero_op
+  %fill_op = pdl.operation "linalg.fill" (%zero, %out_init : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  %fill = pdl.result 0 of %fill_op
+
+  // Match a matmul-like generic with the above indexing maps.
+  %generic_op = pdl.operation (%lhs, %rhs, %fill : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
   pdl.apply_native_constraint "matchContraction"(
         %generic_op, %elemtypes, %imaps
         : !pdl.operation, !pdl.attribute, !pdl.attribute)
@@ -94,13 +101,20 @@
   %lhs_type = pdl.type
   %rhs_type = pdl.type
   %out_type = pdl.type
+  %zero_type = pdl.type : f32
 
   %lhs = pdl.operand : %lhs_type
   %rhs = pdl.operand : %rhs_type
   %out_init = pdl.operand : %out_type
 
+  %zero_val = pdl.attribute = 0. : f32
+  %zero_op = pdl.operation "arith.constant" {"value" = %zero_val} -> (%zero_type : !pdl.type)
+  %zero = pdl.result 0 of %zero_op
+  %fill_op = pdl.operation "linalg.fill" (%zero, %out_init : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  %fill = pdl.result 0 of %fill_op
+
   // Match the a matmul-like generic with above indexing maps.
-  %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  %generic_op = pdl.operation (%lhs, %rhs, %fill : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
   pdl.apply_native_constraint "matchContraction"(
         %generic_op, %elemtypes, %imaps
         : !pdl.operation, !pdl.attribute, !pdl.attribute)
@@ -180,13 +194,20 @@
   %lhs_type = pdl.type
   %rhs_type = pdl.type
   %out_type = pdl.type
+  %zero_type = pdl.type : f32
 
   %lhs = pdl.operand : %lhs_type
   %rhs = pdl.operand : %rhs_type
   %out_init = pdl.operand : %out_type
 
+  %zero_val = pdl.attribute = 0. : f32
+  %zero_op = pdl.operation "arith.constant" {"value" = %zero_val} -> (%zero_type : !pdl.type)
+  %zero = pdl.result 0 of %zero_op
+  %fill_op = pdl.operation "linalg.fill" (%zero, %out_init : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  %fill = pdl.result 0 of %fill_op
+
   // Match the a matmul-like generic with above indexing maps.
-  %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  %generic_op = pdl.operation (%lhs, %rhs, %fill : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
   pdl.apply_native_constraint "matchContraction"(
         %generic_op, %elemtypes, %imaps
         : !pdl.operation, !pdl.attribute, !pdl.attribute)
@@ -256,13 +277,20 @@
   %lhs_type = pdl.type
   %rhs_type = pdl.type
   %out_type = pdl.type
+  %zero_type = pdl.type : f32
 
   %lhs = pdl.operand : %lhs_type
   %rhs = pdl.operand : %rhs_type
   %out_init = pdl.operand : %out_type
 
+  %zero_val = pdl.attribute = 0. : f32
+  %zero_op = pdl.operation "arith.constant" {"value" = %zero_val} -> (%zero_type : !pdl.type)
+  %zero = pdl.result 0 of %zero_op
+  %fill_op = pdl.operation "linalg.fill" (%zero, %out_init : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  %fill = pdl.result 0 of %fill_op
+
   // Match the a matmul-like generic with above indexing maps.
-  %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  %generic_op = pdl.operation (%lhs, %rhs, %fill : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
   pdl.apply_native_constraint "matchContraction"(
         %generic_op, %elemtypes, %imaps
         : !pdl.operation, !pdl.attribute, !pdl.attribute)
@@ -339,13 +367,20 @@
   %lhs_type = pdl.type
   %rhs_type = pdl.type
   %out_type = pdl.type
+  %zero_type = pdl.type : f32
 
   %lhs = pdl.operand : %lhs_type
   %rhs = pdl.operand : %rhs_type
   %out_init = pdl.operand : %out_type
 
+  %zero_val = pdl.attribute = 0. : f32
+  %zero_op = pdl.operation "arith.constant" {"value" = %zero_val} -> (%zero_type : !pdl.type)
+  %zero = pdl.result 0 of %zero_op
+  %fill_op = pdl.operation "linalg.fill" (%zero, %out_init : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  %fill = pdl.result 0 of %fill_op
+
   // Match the a matmul-like generic with above indexing maps.
-  %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  %generic_op = pdl.operation (%lhs, %rhs, %fill : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
   pdl.apply_native_constraint "matchContraction"(
         %generic_op, %elemtypes, %imaps
         : !pdl.operation, !pdl.attribute, !pdl.attribute)
@@ -385,7 +420,7 @@
 
     %config_name = pdl.attribute = "compilation_info"
     %config = pdl.attribute = #iree_codegen.compilation_info<
-      lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 128, 256, 0]}>,
+      lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 256, 256, 0]}>,
       translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
         workgroup_size = [512, 1, 1] subgroup_size = 64,
         // This strategy uses the maximum amount of possible shared memory on
@@ -422,13 +457,20 @@
   %lhs_type = pdl.type
   %rhs_type = pdl.type
   %out_type = pdl.type
+  %zero_type = pdl.type : f32
 
   %lhs = pdl.operand : %lhs_type
   %rhs = pdl.operand : %rhs_type
   %out_init = pdl.operand : %out_type
 
+  %zero_val = pdl.attribute = 0. : f32
+  %zero_op = pdl.operation "arith.constant" {"value" = %zero_val} -> (%zero_type : !pdl.type)
+  %zero = pdl.result 0 of %zero_op
+  %fill_op = pdl.operation "linalg.fill" (%zero, %out_init : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  %fill = pdl.result 0 of %fill_op
+
   // Match the a matmul-like generic with above indexing maps.
-  %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  %generic_op = pdl.operation (%lhs, %rhs, %fill : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
   pdl.apply_native_constraint "matchContraction"(
         %generic_op, %elemtypes, %imaps
         : !pdl.operation, !pdl.attribute, !pdl.attribute)
@@ -498,13 +540,20 @@
   %lhs_type = pdl.type
   %rhs_type = pdl.type
   %out_type = pdl.type
+  %zero_type = pdl.type : f32
 
   %lhs = pdl.operand : %lhs_type
   %rhs = pdl.operand : %rhs_type
   %out_init = pdl.operand : %out_type
 
+  %zero_val = pdl.attribute = 0. : f32
+  %zero_op = pdl.operation "arith.constant" {"value" = %zero_val} -> (%zero_type : !pdl.type)
+  %zero = pdl.result 0 of %zero_op
+  %fill_op = pdl.operation "linalg.fill" (%zero, %out_init : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  %fill = pdl.result 0 of %fill_op
+
   // Match the a matmul-like generic with above indexing maps.
-  %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  %generic_op = pdl.operation (%lhs, %rhs, %fill : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
   pdl.apply_native_constraint "matchContraction"(
         %generic_op, %elemtypes, %imaps
         : !pdl.operation, !pdl.attribute, !pdl.attribute)
@@ -583,13 +632,20 @@
   %lhs_type = pdl.type
   %rhs_type = pdl.type
   %out_type = pdl.type
+  %zero_type = pdl.type : f32
 
   %lhs = pdl.operand : %lhs_type
   %rhs = pdl.operand : %rhs_type
   %out_init = pdl.operand : %out_type
 
+  %zero_val = pdl.attribute = 0. : f32
+  %zero_op = pdl.operation "arith.constant" {"value" = %zero_val} -> (%zero_type : !pdl.type)
+  %zero = pdl.result 0 of %zero_op
+  %fill_op = pdl.operation "linalg.fill" (%zero, %out_init : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  %fill = pdl.result 0 of %fill_op
+
   // Match the a matmul-like generic with above indexing maps.
-  %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  %generic_op = pdl.operation (%lhs, %rhs, %fill : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
   pdl.apply_native_constraint "matchContraction"(
         %generic_op, %elemtypes, %imaps
         : !pdl.operation, !pdl.attribute, !pdl.attribute)