[Codegen] Add XOR-based Swizzle Attribute (#21562)
This adds a new swizzle attribute: `xor_shuffle`.
It swizzles an element at `(row, col)` into `(row, col_swizzled)` with
`col_swizzled = ((row / perPhase) % maxPhase) ^ col`.
Definition is :
`#iree_codegen.xor_shuffle<row_width, access_width, row_stride,
per_phase>`
By default, row_stride == row_width and per_phase=1
Example usage :
```
%alloc = memref.alloc() : memref<32768xi8, #gpu.address_space<workgroup>>
%alloc_swizzle = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<128, 16>] : memref<32768xi8, #gpu.address_space<workgroup>>
```
To do reverse swizzling on GMEM loads using gather_to_lds:
`%val = iree_codegen.swizzle_hint
%rawBuffer[#iree_codegen.xor_shuffle<128, 16, 8192>] : memref<?xi8,
strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>`
---------
Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp b/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp
index f080e20..3ace270 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp
@@ -187,6 +187,11 @@
continue;
}
if (auto gatherToLDSOp = dyn_cast<amdgpu::GatherToLDSOp>(user)) {
+      // Ignore the swizzle hint on the dst operand: gather_to_lds writes the
+      // elements of a subgroup contiguously in order of lane ID.
+ if (gatherToLDSOp.getDst() == hintOp) {
+ continue;
+ }
int64_t accessBitWidth = cast<MemRefType>(hintOp.getOperand().getType())
.getElementTypeBitWidth() *
accessWidth;
@@ -201,11 +206,11 @@
if (accessBitWidth != transferBitWidth) {
return;
}
-
gatherToLDSOps.push_back(gatherToLDSOp);
continue;
}
- // Bail out if we can't rewrite all users.
+    // Emit an error if we can't rewrite all users.
+ hintOp.emitError() << "unsupported SwizzleHintOp user: " << user;
return;
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir b/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir
index badf510..90c56b4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-resolve-swizzle-hints, canonicalize, cse))" \
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-resolve-swizzle-hints, canonicalize, cse))" --verify-diagnostics \
// RUN: --split-input-file --mlir-print-local-scope %s | FileCheck %s
func.func @swizzle_load(%src: memref<?xf32>) -> vector<4xf32> {
@@ -70,6 +70,7 @@
// -----
func.func @drop_swizzle_non_access_user(%src: memref<?xf32>) -> (memref<?xf32>, vector<4xf32>) {
+ // expected-error @+1 {{unsupported SwizzleHintOp user}}
%0 = iree_codegen.swizzle_hint %src[#iree_codegen.rotate_rows<64, 4>] : memref<?xf32>
%offset = arith.constant 68 : index
%1 = vector.load %0[%offset] : memref<?xf32>, vector<4xf32>
@@ -243,3 +244,82 @@
// CHECK: %[[IELEM:.+]] = arith.muli %[[I]], %[[ROW_WIDTH]] : index
// CHECK: %[[SWOFF:.+]] = arith.addi %[[ROTATEJ]], %[[IELEM]] : index
// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SWOFF]]], %[[LDS]][%[[DSTOFFSET]]]
+
+
+func.func @swizzle_load_xor(%src: memref<?xi8>) -> vector<16xi8> {
+ %0 = iree_codegen.swizzle_hint %src[#iree_codegen.xor_shuffle<128, 16>] : memref<?xi8>
+
+ //((int(1952/128) % 8 )^(int(1952/16) %8))*16+ int(1952/128)*128 -> 2000
+ %offset = arith.constant 1952 : index
+ %1 = vector.load %0[%offset] : memref<?xi8>, vector<16xi8>
+ return %1: vector<16xi8>
+}
+
+// CHECK-LABEL: func @swizzle_load_xor
+// CHECK-SAME: %[[SRC:[A-Za-z0-9]+]]: memref<?xi8>
+// CHECK: %[[SWOFF:.+]] = arith.constant 2000 : index
+// CHECK: %[[VECTOR:.+]] = vector.load %[[SRC]][%[[SWOFF]]]
+// CHECK: return %[[VECTOR]]
+
+// -----
+
+func.func @swizzle_load_xor_phase2(%src: memref<?xi8>) -> vector<16xi8> {
+ %0 = iree_codegen.swizzle_hint %src[#iree_codegen.xor_shuffle<128, 16, 128, 2>] : memref<?xi8>
+
+ %offset = arith.constant 1056 : index
+ %1 = vector.load %0[%offset] : memref<?xi8>, vector<16xi8>
+ return %1: vector<16xi8>
+}
+
+// CHECK-LABEL: func @swizzle_load_xor_phase2
+// CHECK-SAME: %[[SRC:[A-Za-z0-9]+]]: memref<?xi8>
+// CHECK: %[[SWOFF:.+]] = arith.constant 1120 : index
+// CHECK: %[[VECTOR:.+]] = vector.load %[[SRC]][%[[SWOFF]]]
+// CHECK: return %[[VECTOR]]
+
+// -----
+
+
+func.func @swizzle_raw_buffer_to_lds(%global : memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ //1 row, 3rd tile : 1*8192+2*128 = 8448 -> (0 XOR 1)*16+8448 = 8464
+ %offset = arith.constant 8448 : index
+ %lds = memref.alloc() : memref<32768xi8, #gpu.address_space<workgroup>>
+ %globalSwizzle = iree_codegen.swizzle_hint %global[#iree_codegen.xor_shuffle<128, 16, 8192>] : memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>
+ amdgpu.gather_to_lds %globalSwizzle[%offset], %lds[%c0]
+ : vector<16xi8>, memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>, memref<32768xi8, #gpu.address_space<workgroup>>
+
+ func.return
+}
+
+// CHECK-LABEL: func @swizzle_raw_buffer_to_lds
+// CHECK-SAME: %[[SRC:.+]]: memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK: %[[SWOFF:.+]] = arith.constant 8464 : index
+// CHECK: %[[LDSOFFSET:.+]] = arith.constant 0 : index
+// CHECK: %[[LDS:.+]] = memref.alloc() : memref<32768xi8, #gpu.address_space<workgroup>>
+// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SWOFF]]], %[[LDS]][%[[LDSOFFSET]]]
+
+// -----
+
+
+func.func @swizzle_raw_buffer_to_lds_ignore_dst_op(%global : memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ //1 row, 3rd tile : 1*8192+2*128 = 8448 -> (0 XOR 1)*16+8448 = 8464
+ %offset = arith.constant 8448 : index
+ %lds = memref.alloc() : memref<32768xi8, #gpu.address_space<workgroup>>
+ %ldsSwizzle = iree_codegen.swizzle_hint %lds[#iree_codegen.xor_shuffle<128, 16>] : memref<32768xi8, #gpu.address_space<workgroup>>
+ %globalSwizzle = iree_codegen.swizzle_hint %global[#iree_codegen.xor_shuffle<128, 16, 8192>] : memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>
+ amdgpu.gather_to_lds %globalSwizzle[%offset], %ldsSwizzle[%c0]
+ : vector<16xi8>, memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>, memref<32768xi8, #gpu.address_space<workgroup>>
+
+ func.return
+}
+
+// CHECK-LABEL: func @swizzle_raw_buffer_to_lds_ignore_dst_op
+// CHECK-SAME: %[[SRC:.+]]: memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK: %[[SWOFF:.+]] = arith.constant 8464 : index
+// CHECK: %[[LDSOFFSET:.+]] = arith.constant 0 : index
+// CHECK: %[[LDS:.+]] = memref.alloc() : memref<32768xi8, #gpu.address_space<workgroup>>
+// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SWOFF]]], %[[LDS]][%[[LDSOFFSET]]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
index c26fd24..16594be 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
@@ -612,6 +612,117 @@
return symbolTable.lookup(name);
}
+//===---------------------------------------------------------------------===//
+// iree_codegen.xor_shuffle
+//===---------------------------------------------------------------------===//
+
+/// Extract the column index for XOR swizzling:
+/// col = (id % rowAlignment) / accessWidth
+static Value extractCol(OpBuilder &builder, Location loc, OpFoldResult id,
+ OpFoldResult rowAlignment, OpFoldResult accessWidth) {
+ AffineExpr d0, s0, s1;
+ bindDims(builder.getContext(), d0);
+ bindSymbols(builder.getContext(), s0, s1);
+ AffineExpr result = (d0 % s0).floorDiv(s1);
+ return getValueOrCreateConstantIndexOp(
+ builder, loc,
+ affine::makeComposedFoldedAffineApply(builder, loc, result,
+ {id, rowAlignment, accessWidth}));
+}
+
+/// Extract the row index for XOR swizzling:
+/// row = ((id / rowStride) / perPhase) % rowAccessAlignment
+static Value extractRow(OpBuilder &builder, Location loc, OpFoldResult id,
+ OpFoldResult rowStride, OpFoldResult perPhase,
+ OpFoldResult rowAccessAlignment) {
+ AffineExpr d0, s0, s1, s2;
+ bindDims(builder.getContext(), d0);
+ bindSymbols(builder.getContext(), s0, s1, s2);
+ AffineExpr result = (d0.floorDiv(s0).floorDiv(s1)) % s2;
+ return getValueOrCreateConstantIndexOp(
+ builder, loc,
+ affine::makeComposedFoldedAffineApply(
+ builder, loc, result, {id, rowStride, perPhase, rowAccessAlignment}));
+}
+
+/// Swizzle column on id.
+/// new_id = id-id%rowAlignmentVal+colSwizzled*accessWidthVal
+static Value updateCol(OpBuilder &builder, Location loc, OpFoldResult id,
+ Value colSwizzled, OpFoldResult rowAlignment,
+ OpFoldResult accessWidth) {
+ AffineExpr d0, d1, s0, s1;
+ bindDims(builder.getContext(), d0, d1);
+ bindSymbols(builder.getContext(), s0, s1);
+ AffineExpr result = d0 - d0 % s0 + d1 * s1;
+ return getValueOrCreateConstantIndexOp(
+ builder, loc,
+ affine::makeComposedFoldedAffineApply(
+ builder, loc, result, {id, colSwizzled, rowAlignment, accessWidth}));
+}
+
+OpFoldResult XORShuffleAttr::swizzleOffset(OpBuilder &b, Location loc,
+ OpFoldResult offset,
+ Value src) const {
+ int64_t rotationInvariant =
+ getRowWidth() * (getRowWidth() / getAccessWidth());
+ int64_t rowStride =
+ getRowStride() != int64_t() ? getRowStride() : getRowWidth();
+ int64_t perPhase = getPerPhase() != int64_t() ? getPerPhase() : 1;
+
+ OpFoldResult id =
+ getMinimumConstantOffsetValue(b, loc, offset, rotationInvariant);
+ Value idVal = getValueOrCreateConstantIndexOp(b, loc, id);
+
+ // Number of elements per row.
+ Value rowAlignmentVal = b.create<arith::ConstantIndexOp>(loc, getRowWidth());
+ // Number of elements per group.
+ Value accessWidthVal =
+ b.create<arith::ConstantIndexOp>(loc, getAccessWidth());
+ // Number of rows per phase.
+ Value perPhaseVal = b.create<arith::ConstantIndexOp>(loc, perPhase);
+ // Buffer stride.
+ Value rowStrideVal = b.create<arith::ConstantIndexOp>(loc, rowStride);
+ // Number of contiguous groups of elements per row (swizzled together).
+ Value rowAccessAlignmentVal =
+ b.create<arith::ConstantIndexOp>(loc, getRowWidth() / getAccessWidth());
+
+ Value colVal = extractCol(b, loc, idVal, rowAlignmentVal, accessWidthVal);
+ Value rowVal = extractRow(b, loc, idVal, rowStrideVal, perPhaseVal,
+ rowAccessAlignmentVal);
+ auto colSwizzled = b.create<arith::XOrIOp>(loc, rowVal, colVal);
+
+  // Fold the swizzled column back into the original id.
+ Value swizzledIdVal =
+ updateCol(b, loc, idVal, colSwizzled, rowAlignmentVal, accessWidthVal);
+ Value diff = b.create<arith::SubIOp>(loc, swizzledIdVal, idVal);
+ return b
+ .create<arith::AddIOp>(
+ loc, getValueOrCreateConstantIndexOp(b, loc, offset), diff)
+ .getResult();
+}
+
+int64_t XORShuffleAttr::getAccessElementCount() const {
+ return getAccessWidth();
+}
+
+LogicalResult
+XORShuffleAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+ int64_t rowWidth, int64_t accessWidth, int64_t rowStride,
+ int64_t perPhase) {
+ if (rowWidth % accessWidth != 0) {
+ return emitError() << "expected access width to divide row width";
+ }
+ int64_t maxPhase = rowWidth / accessWidth;
+ if (perPhase > maxPhase) {
+ return emitError() << "per_phase must be smaller than max_phase";
+ }
+ if (rowStride % rowWidth != 0) {
+ return emitError() << "expected row width to divide row stride";
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Initialize attributes
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
index aa012e3..86a3bbf 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
@@ -549,4 +549,73 @@
}];
}
+
+//===---------------------------------------------------------------------===//
+// iree_codegen.xor_shuffle
+//===---------------------------------------------------------------------===//
+
+def IREECodegen_XORShuffleAttr :
+ AttrDef<IREECodegen_Dialect, "XORShuffle", [
+ DeclareAttrInterfaceMethods<IREECodegen_SwizzleAttrInterface, [
+ "swizzleOffset",
+ "getAccessElementCount"
+ ]>
+ ]> {
+ let mnemonic = "xor_shuffle";
+ let summary = "An attribute that describes an XOR-based swizzling pattern.";
+ let description = [{
+ Shuffles accesses of |access_width| within rows of size
+ |row_width|. For any given access into logical memref of shape
+ `memref<...xNx|access_width|x!eltype>` where `N = row_width / access_width`
+    at position `(i, j, 0)` is shuffled to `(i, ((i/per_phase) % N) XOR j, 0)`. For example,
+
+ ```
+ row_width = 16, access_width = 4, per_phase = 1
+
+ 0000 1111 2222 3333 /// 0 1 2 3
+ 4444 5555 6666 7777 /// 0 1 2 3
+ 8888 9999 AAAA BBBB /// 0 1 2 3
+ CCCC DDDD EEEE FFFF /// 0 1 2 3
+ ```
+
+ is swizzled to
+ ```
+ 0000 1111 2222 3333 /// 0 1 2 3
+ 7777 4444 5555 6666 /// 1 0 3 2
+ BBBB AAAA 8888 9999 /// 2 3 0 1
+ FFFF EEEE DDDD CCCC /// 3 2 1 0
+ ```
+    |per_phase| allows keeping the same shuffling across multiple rows. For example,
+
+ ```
+ row_width = 16, access_width = 4, per_phase = 2
+
+ 0000 1111 2222 3333 /// 0 1 2 3
+ 4444 5555 6666 7777 /// 0 1 2 3
+ 8888 9999 AAAA BBBB /// 0 1 2 3
+ CCCC DDDD EEEE FFFF /// 0 1 2 3
+ ```
+
+ is swizzled to
+ ```
+ 0000 1111 2222 3333 /// 0 1 2 3
+ 7777 4444 5555 6666 /// 0 1 2 3
+ BBBB AAAA 8888 9999 /// 1 0 3 2
+ FFFF EEEE DDDD CCCC /// 1 0 3 2
+ ```
+
+ The pattern repeats for subsequent rows.
+ }];
+ let parameters = (ins
+ AttrParameter<"int64_t", "">:$row_width,
+ AttrParameter<"int64_t", "">:$access_width,
+ OptionalParameter<"int64_t", "row stride. Default to row_width">:$row_stride,
+ OptionalParameter<"int64_t", "Default to 1">:$per_phase
+ );
+ let assemblyFormat = [{
+ `<` $row_width `,` $access_width (`,` $row_stride^)? (`,` $per_phase^)? `>`
+ }];
+ let genVerifyDecl = 1;
+}
+
#endif // IREE_COMPILER_CODEGEN_DIALECT_IREECODEGENATTRS
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index d9e2f49..994beb7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1208,6 +1208,8 @@
if (forROCDL) {
funcPassManager.addPass(amdgpu::createAmdgpuMaskedloadToLoadPass);
+ // This pass needs to run before the ResolveSwizzleHints pass.
+ funcPassManager.addPass(amdgpu::createAmdgpuFoldMemRefOpsPass);
}
// This pass needs to run before SCF -> CF.