[CPU] Add missing BitCast lowering patterns. (#17810)
The revision drops the local LLVM revert
https://github.com/llvm/llvm-project/commit/137a7451f458cf7d8e1d88df93dbd8da6888886d
Fixes https://github.com/iree-org/iree/issues/17780
Co-authored-by: Andrzej WarzyĆski <andrzej.warzynski@arm.com>
---------
Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
index 22b4ce0..6029d2e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
@@ -1049,6 +1049,9 @@
arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
arith::populateExpandBFloat16Patterns(patterns);
populateVectorToSCFConversionPatterns(patterns);
+ // Some n-D vectors are generated by EmulateNarrowType pass, so we need to
+ // unroll them to 1-D before converting to the LLVM dialect.
+ vector::populateVectorBitCastLoweringPatterns(patterns);
populateVectorToLLVMMatrixConversionPatterns(typeConverter, patterns);
populateVectorToLLVMConversionPatterns(
typeConverter, patterns, targetReassociateFpReductions.getValue());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/convert_to_llvm.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/convert_to_llvm.mlir
index 7aa3a3c..1d5686e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/convert_to_llvm.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/convert_to_llvm.mlir
@@ -54,7 +54,7 @@
%c4096 = arith.constant 4096 : index
%c8192 = arith.constant 8192 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c4096) flags(ReadOnly) : memref<128xi8, strided<[1], offset: 4096>>
- %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c8192) : memref<256x64xi8, strided<[64, 1], offset: 8192>>
+ %out_buffer = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c8192) : memref<256x64xi4, strided<[64, 1], offset: 8192>>
%2 = vector.load %0[%c0] : memref<128xi8, strided<[1], offset: 4096>>, vector<2xi8>
%3 = vector.bitcast %2 : vector<2xi8> to vector<4xi4>
%4 = vector.insert %3, %cst_0 [3] : vector<4xi4> into vector<4x4xi4>
@@ -62,15 +62,12 @@
%6 = arith.shli %5, %cst : vector<4x2xi8>
%7 = arith.shrsi %6, %cst : vector<4x2xi8>
%8 = arith.shrsi %5, %cst : vector<4x2xi8>
+
+ // Ops that should be lowered
%9 = vector.interleave %7, %8 : vector<4x2xi8> -> vector<4x4xi8>
- %10 = vector.extract %9[0] : vector<4xi8> from vector<4x4xi8>
- %11 = vector.extract %9[1] : vector<4xi8> from vector<4x4xi8>
- %12 = vector.extract %9[2] : vector<4xi8> from vector<4x4xi8>
- %13 = vector.extract %9[3] : vector<4xi8> from vector<4x4xi8>
- vector.store %10, %1[%c0, %c0] : memref<256x64xi8, strided<[64, 1], offset: 8192>>, vector<4xi8>
- vector.store %11, %1[%c1, %c0] : memref<256x64xi8, strided<[64, 1], offset: 8192>>, vector<4xi8>
- vector.store %12, %1[%c2, %c0] : memref<256x64xi8, strided<[64, 1], offset: 8192>>, vector<4xi8>
- vector.store %13, %1[%c3, %c0] : memref<256x64xi8, strided<[64, 1], offset: 8192>>, vector<4xi8>
+ %14 = vector.bitcast %9 : vector<4x4xi8> to vector<4x8xi4>
+
+ vector.store %14, %out_buffer[%c0, %c0] : memref<256x64xi4, strided<[64, 1], offset: 8192>>, vector<4x8xi4>
return
}
}
@@ -79,5 +76,8 @@
// corresponding multi-dimensional `vector.bitcast`.
// CHECK-LABEL: llvm.func @interleave_and_bitcast_lowering(
-// CHECK-NOT: vector.bitcast %{{.*}} : vector<4x4xi4> to vector<4x2xi8>
+// vector.interleave should be gone entirely
// CHECK-NOT: vector.interleave
+// 2D vector.bitcast tha followed should be replaced with 1D vector.bitcast
+// CHECK: llvm.bitcast {{.*}} : vector<4xi8> to vector<8xi4>
+// CHECK-NOT: vector.bitcast %{{.*}} : vector<4x4xi8> to vector<4x8xi4>
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 4334375..c5bb6d3 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 4334375d5666c628610fc500aaab2059fd3892ba
+Subproject commit c5bb6d3e2eb870b5ae454b410ac190ea05045303