Added bf16 lowerings for cuda backend (#13351)
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index f9f5760..abb0547 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -514,7 +514,8 @@
// math dialect elementry functions -> polynomial form.
pm.addNestedPass<func::FuncOp>(createPolynomialApproximationPass());
- pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
+ pm.addNestedPass<func::FuncOp>(
+ arith::createArithExpandOpsPass({/*include-bf16=*/true}));
pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
pm.addPass(memref::createExpandStridedMetadataPass());
pm.addPass(createLowerAffinePass());
diff --git a/tests/e2e/xla_ops/BUILD.bazel b/tests/e2e/xla_ops/BUILD.bazel
index b047767..0a997ba 100644
--- a/tests/e2e/xla_ops/BUILD.bazel
+++ b/tests/e2e/xla_ops/BUILD.bazel
@@ -40,6 +40,7 @@
"cosine.mlir",
"divide.mlir",
"dot.mlir",
+ "dot_bf16.mlir",
"dot_general.mlir",
"dynamic_slice.mlir",
"dynamic_update_slice.mlir",
@@ -82,7 +83,6 @@
],
include = ["*.mlir"],
exclude = [
- "dot_bf16.mlir", # Missing BF16 support on CUDA backend
"fft.mlir", # TODO(#9583)
],
),
@@ -123,6 +123,7 @@
"cosine.mlir",
"divide.mlir",
"dot.mlir",
+ "dot_bf16.mlir",
"dot_general.mlir",
"dynamic_slice.mlir",
"dynamic_update_slice.mlir",
@@ -165,7 +166,6 @@
],
include = ["*.mlir"],
exclude = [
- "dot_bf16.mlir", # Missing BF16 support on CUDA backend
"fft.mlir", # TODO(#9583)
],
),
diff --git a/tests/e2e/xla_ops/CMakeLists.txt b/tests/e2e/xla_ops/CMakeLists.txt
index 877bc86..e90c5be 100644
--- a/tests/e2e/xla_ops/CMakeLists.txt
+++ b/tests/e2e/xla_ops/CMakeLists.txt
@@ -31,6 +31,7 @@
"cosine.mlir"
"divide.mlir"
"dot.mlir"
+ "dot_bf16.mlir"
"dot_general.mlir"
"dynamic_slice.mlir"
"dynamic_update_slice.mlir"
@@ -107,6 +108,7 @@
"cosine.mlir"
"divide.mlir"
"dot.mlir"
+ "dot_bf16.mlir"
"dot_general.mlir"
"dynamic_slice.mlir"
"dynamic_update_slice.mlir"