[VMVX] Add support for arith.maxnumf and arith.minnumf lowering. (#18033)
Similar to what is done in CPU backend, it adds a ArithExpandOps pass to
the late stage of the pipeline. This drops the local llvm revert (i.e.,
https://github.com/llvm/llvm-project/commit/fa066687) because the VMVX
backend can emulate the ops.
The GPU test changes are because of the upstream commit. It uses
arith.maxnumf in softmax decomposition.
Note: the arith.maximum and arith.minimum ops are also expanded with the
change. They become a seq of cmp + select ops.
Existing local patches carried over:
- https://github.com/llvm/llvm-project/commit/f6431f0c
- https://github.com/llvm/llvm-project/commit/bbd4af5d
Fixes https://github.com/iree-org/iree/issues/17779
---------
Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/decompose_softmax.mlir b/compiler/src/iree/compiler/Codegen/Common/test/decompose_softmax.mlir
index 4a84b95..777ca50 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/decompose_softmax.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/decompose_softmax.mlir
@@ -18,7 +18,7 @@
// CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
// CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D8:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
+// CHECK: %[[D8:.+]] = arith.maxnumf %[[IN]], %[[OUT]] : f32
// CHECK: linalg.yield %[[D8]] : f32
// CHECK: } -> tensor<2x16xf32>
// CHECK: %[[CST0:.+]] = arith.constant 0.0
@@ -54,7 +54,7 @@
// CHECK-NO-FUSE: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
// CHECK-NO-FUSE-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
// CHECK-NO-FUSE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK-NO-FUSE: %[[D8:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
+// CHECK-NO-FUSE: %[[D8:.+]] = arith.maxnumf %[[IN]], %[[OUT]] : f32
// CHECK-NO-FUSE: linalg.yield %[[D8]] : f32
// CHECK-NO-FUSE: } -> tensor<2x16xf32>
// CHECK-NO-FUSE: %[[D4:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir
index 251f6de..28a10a9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir
@@ -232,19 +232,19 @@
// CHECK-SAME: translation_info = #[[TRANSLATION_INFO]]
// CHECK: scf.for {{.*}} -> (vector<4xf32>) {
// CHECK: vector.transfer_read {{.*}} : memref<12x128x40960xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
-// CHECK: arith.maximumf {{.*}} : vector<4xf32>
+// CHECK: arith.maxnumf {{.*}} : vector<4xf32>
// CHECK: scf.yield
-// CHECK: vector.reduction <maximumf>, %{{.*}} : vector<4xf32> into f32
+// CHECK: vector.reduction <maxnumf>, %{{.*}} : vector<4xf32> into f32
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: arith.remui
// CHECK: scf.if
// CHECK: memref.store {{.*}} : memref<32xf32, #gpu.address_space<workgroup>>
@@ -253,16 +253,16 @@
// CHECK: arith.minui
// CHECK: memref.load
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
+// CHECK: arith.maxnumf
// CHECK: vector.broadcast %{{.*}} : f32 to vector<4xf32>
// CHECK: scf.for {{.*}} -> (vector<4xf32>) {
// CHECK: vector.transfer_read
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_reduction.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_reduction.mlir
index 7e703df..9576c05 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_reduction.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_reduction.mlir
@@ -190,19 +190,19 @@
// CHECK-LABEL: func.func @softmax
// CHECK: scf.for {{.*}} -> (vector<4xf32>) {
// CHECK: vector.transfer_read {{.*}} : memref<12x128x40960xf32{{.+}}>, vector<4xf32>
-// CHECK: arith.maximumf {{.*}} : vector<4xf32>
+// CHECK: arith.maxnumf {{.*}} : vector<4xf32>
// CHECK: scf.yield
-// CHECK: vector.reduction <maximumf>, %{{.*}} : vector<4xf32> into f32
+// CHECK: vector.reduction <maxnumf>, %{{.*}} : vector<4xf32> into f32
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: arith.remui
// CHECK: scf.if
// CHECK: memref.store {{.*}} : memref<32xf32, #gpu.address_space<workgroup>>
@@ -211,16 +211,16 @@
// CHECK: arith.minui
// CHECK: memref.load
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
-// CHECK: arith.maximumf
-// CHECK: arith.maximumf
+// CHECK: arith.maxnumf
+// CHECK: arith.maxnumf
// CHECK: vector.splat %{{.*}} : vector<4xf32>
// CHECK: scf.for {{.*}} -> (vector<4xf32>) {
// CHECK: vector.transfer_read
@@ -297,7 +297,7 @@
// CHECK-LABEL: func.func @dynamic_softmax
// CHECK-DAG: %[[ADD_PAD:.+]] = arith.constant dense<0.000000e+00> : vector<1xf16>
-// CHECK-DAG: %[[MIN_F16:.+]] = arith.constant dense<0xFC00> : vector<1xf16>
+// CHECK-DAG: %[[MIN_F16:.+]] = arith.constant dense<0xFE00> : vector<1xf16>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
// CHECK-DAG: %[[C0_F16:.+]] = arith.constant 0.000000e+00 : f16
@@ -316,7 +316,7 @@
// CHECK: %[[MASK:.+]] = vector.create_mask %{{.*}} : vector<1xi1>
// CHECK-DAG: %[[ACC:.+]] = vector.transfer_read %{{.*}}, %[[C0_F16]], %[[MASK]] {{.*}} : memref<1x64xf16, #gpu.address_space<workgroup>>, vector<1xf16>
// CHECK-DAG: %[[NEW:.+]] = vector.transfer_read %{{.*}}, %[[C0_F16]], %[[MASK]] {{.*}} : memref<32x?xf16, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
-// CHECK: %[[MAX:.+]] = arith.maximumf %[[NEW]], %[[ACC]] : vector<1xf16>
+// CHECK: %[[MAX:.+]] = arith.maxnumf %[[NEW]], %[[ACC]] : vector<1xf16>
// CHECK: vector.transfer_write %[[MAX]], %{{.*}}, %[[MASK]] {{.*}} : vector<1xf16>, memref<1x64xf16, #gpu.address_space<workgroup>>
// CHECK: gpu.barrier
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
index fc25a70..7d71d6a 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
@@ -81,6 +81,7 @@
.addPass(createCSEPass)
.addPass([]() { return createConvertVectorToSCFPass(); })
.addPass(createCanonicalizerPass)
+ .addPass(arith::createArithExpandOpsPass)
.addPass(memref::createExpandOpsPass);
// Handle tensor-type constants.
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 19b907f..e5d9ef7 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 19b907f2fe56bc71d9395b793ad995549f6fd401
+Subproject commit e5d9ef7704149011bcf68eb22854b6edfabe6634