Fix math.ctlz on vulkan and spirv (#9195)
Make use of the new expansion for ctlz for platforms that do not support it as
an intrinsic.
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD
index 496f44f..f22c7e7 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD
@@ -76,6 +76,7 @@
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:MathToSPIRV",
+ "@llvm-project//mlir:MathTransforms",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MemRefToSPIRV",
"@llvm-project//mlir:MemRefTransforms",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
index 91ac562..c23f954 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
@@ -63,6 +63,7 @@
MLIRLinalg
MLIRLinalgTransforms
MLIRMathToSPIRV
+ MLIRMathTransforms
MLIRMemRef
MLIRMemRefToSPIRV
MLIRMemRefTransforms
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
index 1935e1e..d2995fc 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
@@ -39,6 +39,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -352,6 +353,7 @@
arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
populateFuncToSPIRVPatterns(typeConverter, patterns);
populateMathToSPIRVPatterns(typeConverter, patterns);
+ populateExpandCtlzPattern(patterns);
// Pull in standard patterns to convert tensor operations to SPIR-V. These are
// primarily used to handle tensor-type constants and contain a
diff --git a/compiler/src/iree/compiler/Dialect/Modules/VMVX/Transforms/Conversion.cpp b/compiler/src/iree/compiler/Dialect/Modules/VMVX/Transforms/Conversion.cpp
index bf8bf1d..ef2f51d 100644
--- a/compiler/src/iree/compiler/Dialect/Modules/VMVX/Transforms/Conversion.cpp
+++ b/compiler/src/iree/compiler/Dialect/Modules/VMVX/Transforms/Conversion.cpp
@@ -16,7 +16,10 @@
#include "mlir/Conversion/TosaToArith/TosaToArith.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
@@ -67,8 +70,10 @@
conversionTarget.addIllegalDialect<tensor::TensorDialect>();
conversionTarget.addLegalDialect<IREE::Util::UtilDialect>();
conversionTarget.addLegalDialect<IREE::VMVX::VMVXDialect>();
- conversionTarget.addLegalDialect<mlir::func::FuncDialect,
- mlir::arith::ArithmeticDialect>();
+ conversionTarget
+ .addLegalDialect<mlir::func::FuncDialect, mlir::scf::SCFDialect,
+ mlir::arith::ArithmeticDialect>();
+ conversionTarget.addIllegalOp<math::CountLeadingZerosOp>();
conversionTarget.addLegalDialect<mlir::AffineDialect>();
conversionTarget.addLegalDialect<memref::MemRefDialect>();
conversionTarget.addLegalOp<mlir::UnrealizedConversionCastOp>();
@@ -77,6 +82,7 @@
populateHALToVMVXPatterns(context, conversionPatterns, typeConverter);
populateStandardToVMVXPatterns(context, conversionPatterns, typeConverter);
+ populateExpandCtlzPattern(conversionPatterns);
// Use the default 64-bit lowering for TOSA's ApplyScale operator:
// This lowering widens integer types to 64-bit an performs the non-fused
// operations, specifically multiply, add, and shift. Bit-widening
diff --git a/tests/e2e/tosa_ops/BUILD b/tests/e2e/tosa_ops/BUILD
index 92afeca..27052b3 100644
--- a/tests/e2e/tosa_ops/BUILD
+++ b/tests/e2e/tosa_ops/BUILD
@@ -83,6 +83,7 @@
"bitwise_xor.mlir",
"ceil.mlir",
"clamp.mlir",
+ "clz.mlir",
"const.mlir",
"equal.mlir",
"exp.mlir",
@@ -116,7 +117,6 @@
],
include = ["*.mlir"],
exclude = [
- "clz.mlir", # https://github.com/google/iree/issues/9152
"reduce.mlir", # Currently flakey https://github.com/google/iree/issues/5885
],
)
@@ -141,6 +141,7 @@
"bitwise_xor.mlir",
"ceil.mlir",
"clamp.mlir",
+ "clz.mlir",
"const.mlir",
"equal.mlir",
"exp.mlir",
@@ -174,9 +175,6 @@
"while.mlir",
],
include = ["*.mlir"],
- exclude = [
- "clz.mlir", # https://github.com/google/iree/issues/9152
- ],
)
iree_check_single_backend_test_suite(
diff --git a/tests/e2e/tosa_ops/CMakeLists.txt b/tests/e2e/tosa_ops/CMakeLists.txt
index 6dc3430..61050c4 100644
--- a/tests/e2e/tosa_ops/CMakeLists.txt
+++ b/tests/e2e/tosa_ops/CMakeLists.txt
@@ -74,6 +74,7 @@
"bitwise_xor.mlir"
"ceil.mlir"
"clamp.mlir"
+ "clz.mlir"
"const.mlir"
"equal.mlir"
"exp.mlir"
@@ -124,6 +125,7 @@
"bitwise_xor.mlir"
"ceil.mlir"
"clamp.mlir"
+ "clz.mlir"
"const.mlir"
"equal.mlir"
"exp.mlir"