Disable folding casting ops into contraction ops by default. (#15342)
This is only required by LLVMGPU backend, the behavior was accidentally
changed during refactoring. The revision turns it back to the old
behavior.
Fixes https://github.com/openxla/iree/issues/15241
diff --git a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
index eeaaa16..b164dfc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
@@ -175,6 +175,7 @@
this->vectorizeGatherAccesses.setValue(options.vectorizeGatherAccesses);
this->enableCleanup.setValue(options.enableCleanup);
this->generateContract.setValue(options.generateContract);
+ this->foldCastIntoContract.setValue(options.foldCastIntoContract);
this->maxVectorSize.setValue(options.maxVectorSize);
}
@@ -253,6 +254,8 @@
vector::populateVectorTransferPermutationMapLoweringPatterns(
vectorizationPatterns);
vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
+ }
+ if (foldCastIntoContract) {
vector::populateFoldArithExtensionPatterns(vectorizationPatterns);
}
if (enableVectorMasking) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h
index eedd9e3..2eae4a8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -154,6 +154,9 @@
bool enableCleanup = true;
// Enable conversion for reduction ops to contraction ops.
bool generateContract = true;
+ // Enable folding casting ops into contraction ops. Note that the resulting
+ // mixed-type contraction ops are only handled by certain backends.
+ bool foldCastIntoContract = false;
// Max vector size allowed to avoid creating large vectors.
int64_t maxVectorSize = std::numeric_limits<int64_t>::max();
};
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index 9abfe85..477ff82 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -269,6 +269,8 @@
"generated from tiling so it affects later steps like bufferization and vector hoisting.">,
Option<"generateContract", "generate-contract", "bool",/*default=*/"true",
"Enable conversion for reduction ops to contraction ops.">,
+ Option<"foldCastIntoContract", "fold-cast-into-contract", "bool",/*default=*/"false",
+ "Enable folding casting ops into vector.contract.">,
Option<"maxVectorSize", "max-vector-size", "int64_t",
/*default=*/"2147483647",
"Max vector size allowed to avoid creating large vectors.">
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index 97a04a6..69cd467 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -42,6 +42,7 @@
"fold_affine_min_of_block_id.mlir",
"fold_tensor_extract_op.mlir",
"forop_canonicalization.mlir",
+ "generic_vectorization.mlir",
"hoist_statically_bound_allocations.mlir",
"hoist_unrolled_vector_extract_insert_slice.mlir",
"iree_comprehensive_bufferize.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index 3e839a1..2569a72 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -38,6 +38,7 @@
"fold_affine_min_of_block_id.mlir"
"fold_tensor_extract_op.mlir"
"forop_canonicalization.mlir"
+ "generic_vectorization.mlir"
"hoist_statically_bound_allocations.mlir"
"hoist_unrolled_vector_extract_insert_slice.mlir"
"iree_comprehensive_bufferize.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/generic_vectorization.mlir b/compiler/src/iree/compiler/Codegen/Common/test/generic_vectorization.mlir
new file mode 100644
index 0000000..c45e492
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/generic_vectorization.mlir
@@ -0,0 +1,26 @@
+// RUN: iree-opt --iree-codegen-generic-vectorization --split-input-file %s | FileCheck %s
+// RUN: iree-opt --iree-codegen-generic-vectorization="fold-cast-into-contract=true" --split-input-file %s | FileCheck %s -check-prefix=CHECK-FOLD
+
+func.func @matmul(%lhs: tensor<3x4xf16>, %rhs: tensor<4x5xf16>, %acc: tensor<3x5xf32>) -> tensor<3x5xf32> {
+ %result = linalg.matmul ins(%lhs, %rhs: tensor<3x4xf16>, tensor<4x5xf16>) outs(%acc: tensor<3x5xf32>) -> tensor<3x5xf32>
+ return %result: tensor<3x5xf32>
+}
+// CHECK-LABEL: func.func @matmul
+// CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]
+// CHECK: %[[LHS_VEC:.+]] = vector.transfer_read %[[LHS]]
+// CHECK: %[[RHS_VEC:.+]] = vector.transfer_read %[[RHS]]
+// CHECK: %[[OUT_VEC:.+]] = vector.transfer_read %[[OUT]]
+// CHECK: %[[EXT_LHS:.+]] = arith.extf %[[LHS_VEC]]
+// CHECK: %[[EXT_RHS:.+]] = arith.extf %[[RHS_VEC]]
+// CHECK: %[[RES:.+]] = vector.contract {{.+}} %[[EXT_LHS]], %[[EXT_RHS]], %[[OUT_VEC]]
+
+// CHECK-FOLD-LABEL: func.func @matmul
+// CHECK-FOLD-SAME: %[[LHS:[a-zA-Z0-9]+]]
+// CHECK-FOLD-SAME: %[[RHS:[a-zA-Z0-9]+]]
+// CHECK-FOLD-SAME: %[[OUT:[a-zA-Z0-9]+]]
+// CHECK-FOLD: %[[LHS_VEC:.+]] = vector.transfer_read %[[LHS]]
+// CHECK-FOLD: %[[RHS_VEC:.+]] = vector.transfer_read %[[RHS]]
+// CHECK-FOLD: %[[OUT_VEC:.+]] = vector.transfer_read %[[OUT]]
+// CHECK-FOLD: %[[RES:.+]] = vector.contract {{.+}} %[[LHS_VEC]], %[[RHS_VEC]], %[[OUT_VEC]]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 556b36d..6ea6a99 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -110,6 +110,7 @@
options.vectorizePadding = true;
options.vectorizeGatherAccesses = true;
options.enableCleanup = false;
+ options.foldCastIntoContract = true;
options.maxVectorSize = 4096;
pm.addNestedPass<func::FuncOp>(createGenericVectorizationPass(options));
pm.addNestedPass<func::FuncOp>(createHoistRedundantVectorTransfersPass());