[Codegen] Fix bug in IGEMM pass for non-conv contractions (#18838)
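Adds back a match condition in the ConvolutionToIGEMM pass that was lost
during code cleanup: the pattern now reports a match failure when the
generic op has no im2col op producer. Also adds a test for the previously
failing case, a plain contraction that does not come from a convolution.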
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp
index bcb5fe5..58b678c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp
@@ -45,6 +45,9 @@
break;
}
}
+ if (!im2colOp) {
+ return rewriter.notifyMatchFailure(genericOp, "no im2colOp producer.");
+ }
if (getLoweringConfig(genericOp)) {
return rewriter.notifyMatchFailure(genericOp,
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir
index ad1d552..3d5494e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir
@@ -85,3 +85,23 @@
// CHECK-NOT: iree_linalg_ext.im2col
// CHECK: linalg.conv_2d_nhwc_hwcf
// CHECK-SAME: lowering_config
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func public @no_conv_contraction(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %cst = arith.constant 0.0 : f32
+ %empty = tensor.empty() : tensor<128x128xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %matmul = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%fill : tensor<128x128xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.mulf %in, %in_0 : f32
+ %1 = arith.addf %0, %out : f32
+ linalg.yield %1 : f32
+ } -> tensor<128x128xf32>
+ return %matmul : tensor<128x128xf32>
+}
+// CHECK: func.func public @no_conv_contraction
+// CHECK: linalg.generic