[Codegen] Add XOR-based Swizzle Attribute (#21562)

This adds a new swizzle attribute: xor_shuffle.
It swizzles an element at `(row, col)` into `(row, col_swizzled)` with
`col_swizzled = ((row/perPhase) % maxPhase) ^ col`.

The definition is:
`#iree_codegen.xor_shuffle<row_width, access_width, row_stride,
per_phase>`

By default, row_stride == row_width and per_phase == 1.
Example usage:
```
%alloc = memref.alloc() : memref<32768xi8, #gpu.address_space<workgroup>>
        %alloc_swizzle = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<128, 16>] : memref<32768xi8, #gpu.address_space<workgroup>>
```

To do reverse swizzling on GMEM loads using `amdgpu.gather_to_lds`:
`%val = iree_codegen.swizzle_hint
%rawBuffer[#iree_codegen.xor_shuffle<128, 16, 8192>] : memref<?xi8,
strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>`

---------

Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp b/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp
index f080e20..3ace270 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp
@@ -187,6 +187,11 @@
       continue;
     }
     if (auto gatherToLDSOp = dyn_cast<amdgpu::GatherToLDSOp>(user)) {
+      // Ignore the swizzle hint on the dst operand: gather_to_lds writes the
+      // elements of a subgroup contiguously in order of lane ID.
+      if (gatherToLDSOp.getDst() == hintOp) {
+        continue;
+      }
       int64_t accessBitWidth = cast<MemRefType>(hintOp.getOperand().getType())
                                    .getElementTypeBitWidth() *
                                accessWidth;
@@ -201,11 +206,11 @@
       if (accessBitWidth != transferBitWidth) {
         return;
       }
-
       gatherToLDSOps.push_back(gatherToLDSOp);
       continue;
     }
-    // Bail out if we can't rewrite all users.
+    // Emit an error if we can't rewrite all users.
+    hintOp.emitError() << "unsupported SwizzleHintOp user: " << user;
     return;
   }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir b/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir
index badf510..90c56b4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-resolve-swizzle-hints, canonicalize, cse))" \
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-resolve-swizzle-hints, canonicalize, cse))" --verify-diagnostics \
 // RUN:   --split-input-file --mlir-print-local-scope %s | FileCheck %s
 
 func.func @swizzle_load(%src: memref<?xf32>) -> vector<4xf32> {
@@ -70,6 +70,7 @@
 // -----
 
 func.func @drop_swizzle_non_access_user(%src: memref<?xf32>) -> (memref<?xf32>, vector<4xf32>) {
+  // expected-error @+1 {{unsupported SwizzleHintOp user}}
   %0 = iree_codegen.swizzle_hint %src[#iree_codegen.rotate_rows<64, 4>] : memref<?xf32>
   %offset = arith.constant 68 : index
   %1 = vector.load %0[%offset] : memref<?xf32>, vector<4xf32>
@@ -243,3 +244,82 @@
 //       CHECK:   %[[IELEM:.+]] = arith.muli %[[I]], %[[ROW_WIDTH]] : index
 //       CHECK:   %[[SWOFF:.+]] = arith.addi %[[ROTATEJ]], %[[IELEM]] : index
 //       CHECK:   amdgpu.gather_to_lds %[[SRC]][%[[SWOFF]]], %[[LDS]][%[[DSTOFFSET]]]
+
+
+func.func @swizzle_load_xor(%src: memref<?xi8>) -> vector<16xi8> {
+  %0 = iree_codegen.swizzle_hint %src[#iree_codegen.xor_shuffle<128, 16>] : memref<?xi8>
+
+  // ((1952/128 % 8) ^ (1952/16 % 8))*16 + (1952/128)*128 = (7 ^ 2)*16 + 1920 = 2000
+  %offset = arith.constant 1952 : index
+  %1 = vector.load %0[%offset] : memref<?xi8>, vector<16xi8>
+  return %1: vector<16xi8>
+}
+
+// CHECK-LABEL: func @swizzle_load_xor
+//  CHECK-SAME:   %[[SRC:[A-Za-z0-9]+]]: memref<?xi8>
+//       CHECK:   %[[SWOFF:.+]] = arith.constant 2000 : index
+//       CHECK:   %[[VECTOR:.+]] = vector.load %[[SRC]][%[[SWOFF]]]
+//       CHECK:   return %[[VECTOR]]
+
+// -----
+
+func.func @swizzle_load_xor_phase2(%src: memref<?xi8>) -> vector<16xi8> {
+  %0 = iree_codegen.swizzle_hint %src[#iree_codegen.xor_shuffle<128, 16, 128, 2>] : memref<?xi8>
+
+  %offset = arith.constant 1056 : index
+  %1 = vector.load %0[%offset] : memref<?xi8>, vector<16xi8>
+  return %1: vector<16xi8>
+}
+
+// CHECK-LABEL: func @swizzle_load_xor_phase2
+//  CHECK-SAME:   %[[SRC:[A-Za-z0-9]+]]: memref<?xi8>
+//       CHECK:   %[[SWOFF:.+]] = arith.constant 1120 : index
+//       CHECK:   %[[VECTOR:.+]] = vector.load %[[SRC]][%[[SWOFF]]]
+//       CHECK:   return %[[VECTOR]]
+
+// -----
+
+
+func.func @swizzle_raw_buffer_to_lds(%global : memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // Row 1, 3rd tile: 1*8192 + 2*128 = 8448 -> (0 XOR 1)*16 + 8448 = 8464
+  %offset = arith.constant 8448 : index
+  %lds = memref.alloc() : memref<32768xi8, #gpu.address_space<workgroup>>
+  %globalSwizzle = iree_codegen.swizzle_hint %global[#iree_codegen.xor_shuffle<128, 16, 8192>] : memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>
+  amdgpu.gather_to_lds %globalSwizzle[%offset], %lds[%c0]
+    : vector<16xi8>, memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>, memref<32768xi8, #gpu.address_space<workgroup>>
+
+  func.return
+}
+
+// CHECK-LABEL: func @swizzle_raw_buffer_to_lds
+//  CHECK-SAME:   %[[SRC:.+]]: memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>
+//   CHECK:   %[[SWOFF:.+]] = arith.constant 8464 : index
+//   CHECK:   %[[LDSOFFSET:.+]] = arith.constant 0 : index
+//       CHECK:   %[[LDS:.+]] = memref.alloc() : memref<32768xi8, #gpu.address_space<workgroup>>
+//       CHECK:   amdgpu.gather_to_lds %[[SRC]][%[[SWOFF]]], %[[LDS]][%[[LDSOFFSET]]]
+
+// -----
+
+
+func.func @swizzle_raw_buffer_to_lds_ignore_dst_op(%global : memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // Row 1, 3rd tile: 1*8192 + 2*128 = 8448 -> (0 XOR 1)*16 + 8448 = 8464
+  %offset = arith.constant 8448 : index
+  %lds = memref.alloc() : memref<32768xi8, #gpu.address_space<workgroup>>
+  %ldsSwizzle = iree_codegen.swizzle_hint %lds[#iree_codegen.xor_shuffle<128, 16>] : memref<32768xi8, #gpu.address_space<workgroup>>
+  %globalSwizzle = iree_codegen.swizzle_hint %global[#iree_codegen.xor_shuffle<128, 16, 8192>] : memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>
+  amdgpu.gather_to_lds %globalSwizzle[%offset], %ldsSwizzle[%c0]
+    : vector<16xi8>, memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>, memref<32768xi8, #gpu.address_space<workgroup>>
+
+  func.return
+}
+
+// CHECK-LABEL: func @swizzle_raw_buffer_to_lds_ignore_dst_op
+//  CHECK-SAME:   %[[SRC:.+]]: memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>
+//   CHECK:   %[[SWOFF:.+]] = arith.constant 8464 : index
+//   CHECK:   %[[LDSOFFSET:.+]] = arith.constant 0 : index
+//       CHECK:   %[[LDS:.+]] = memref.alloc() : memref<32768xi8, #gpu.address_space<workgroup>>
+//       CHECK:   amdgpu.gather_to_lds %[[SRC]][%[[SWOFF]]], %[[LDS]][%[[LDSOFFSET]]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
index c26fd24..16594be 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
@@ -612,6 +612,117 @@
   return symbolTable.lookup(name);
 }
 
+//===---------------------------------------------------------------------===//
+// iree_codegen.xor_shuffle
+//===---------------------------------------------------------------------===//
+
+/// Extract column index for XOR swizzling.
+/// ((id%rowStride) / accessWidth)
+static Value extractCol(OpBuilder &builder, Location loc, OpFoldResult id,
+                        OpFoldResult rowAlignment, OpFoldResult accessWidth) {
+  AffineExpr d0, s0, s1;
+  bindDims(builder.getContext(), d0);
+  bindSymbols(builder.getContext(), s0, s1);
+  AffineExpr result = (d0 % s0).floorDiv(s1);
+  return getValueOrCreateConstantIndexOp(
+      builder, loc,
+      affine::makeComposedFoldedAffineApply(builder, loc, result,
+                                            {id, rowAlignment, accessWidth}));
+}
+
+/// Extract row index for XOR swizzling.
+/// row = ((id/rowStride) / perPhase ) % rowAccessAlignment
+static Value extractRow(OpBuilder &builder, Location loc, OpFoldResult id,
+                        OpFoldResult rowStride, OpFoldResult perPhase,
+                        OpFoldResult rowAccessAlignment) {
+  AffineExpr d0, s0, s1, s2;
+  bindDims(builder.getContext(), d0);
+  bindSymbols(builder.getContext(), s0, s1, s2);
+  AffineExpr result = (d0.floorDiv(s0).floorDiv(s1)) % s2;
+  return getValueOrCreateConstantIndexOp(
+      builder, loc,
+      affine::makeComposedFoldedAffineApply(
+          builder, loc, result, {id, rowStride, perPhase, rowAccessAlignment}));
+}
+
+/// Swizzle column on id.
+/// new_id = id-id%rowAlignmentVal+colSwizzled*accessWidthVal
+static Value updateCol(OpBuilder &builder, Location loc, OpFoldResult id,
+                       Value colSwizzled, OpFoldResult rowAlignment,
+                       OpFoldResult accessWidth) {
+  AffineExpr d0, d1, s0, s1;
+  bindDims(builder.getContext(), d0, d1);
+  bindSymbols(builder.getContext(), s0, s1);
+  AffineExpr result = d0 - d0 % s0 + d1 * s1;
+  return getValueOrCreateConstantIndexOp(
+      builder, loc,
+      affine::makeComposedFoldedAffineApply(
+          builder, loc, result, {id, colSwizzled, rowAlignment, accessWidth}));
+}
+
+OpFoldResult XORShuffleAttr::swizzleOffset(OpBuilder &b, Location loc,
+                                           OpFoldResult offset,
+                                           Value src) const {
+  int64_t rotationInvariant =
+      getRowWidth() * (getRowWidth() / getAccessWidth());
+  int64_t rowStride =
+      getRowStride() != int64_t() ? getRowStride() : getRowWidth();
+  int64_t perPhase = getPerPhase() != int64_t() ? getPerPhase() : 1;
+
+  OpFoldResult id =
+      getMinimumConstantOffsetValue(b, loc, offset, rotationInvariant);
+  Value idVal = getValueOrCreateConstantIndexOp(b, loc, id);
+
+  // Number of elements per row.
+  Value rowAlignmentVal = b.create<arith::ConstantIndexOp>(loc, getRowWidth());
+  // Number of elements per group.
+  Value accessWidthVal =
+      b.create<arith::ConstantIndexOp>(loc, getAccessWidth());
+  // Number of rows per phase.
+  Value perPhaseVal = b.create<arith::ConstantIndexOp>(loc, perPhase);
+  // Buffer stride.
+  Value rowStrideVal = b.create<arith::ConstantIndexOp>(loc, rowStride);
+  // Number of contiguous groups of elements per row (swizzled together).
+  Value rowAccessAlignmentVal =
+      b.create<arith::ConstantIndexOp>(loc, getRowWidth() / getAccessWidth());
+
+  Value colVal = extractCol(b, loc, idVal, rowAlignmentVal, accessWidthVal);
+  Value rowVal = extractRow(b, loc, idVal, rowStrideVal, perPhaseVal,
+                            rowAccessAlignmentVal);
+  auto colSwizzled = b.create<arith::XOrIOp>(loc, rowVal, colVal);
+
+  // Update colSwizzled to initial id
+  Value swizzledIdVal =
+      updateCol(b, loc, idVal, colSwizzled, rowAlignmentVal, accessWidthVal);
+  Value diff = b.create<arith::SubIOp>(loc, swizzledIdVal, idVal);
+  return b
+      .create<arith::AddIOp>(
+          loc, getValueOrCreateConstantIndexOp(b, loc, offset), diff)
+      .getResult();
+}
+
+int64_t XORShuffleAttr::getAccessElementCount() const {
+  return getAccessWidth();
+}
+
+LogicalResult
+XORShuffleAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                       int64_t rowWidth, int64_t accessWidth, int64_t rowStride,
+                       int64_t perPhase) {
+  if (rowWidth % accessWidth != 0) {
+    return emitError() << "expected access width to divide row width";
+  }
+  int64_t maxPhase = rowWidth / accessWidth;
+  if (perPhase > maxPhase) {
+    return emitError() << "per_phase must be smaller than max_phase";
+  }
+  if (rowStride % rowWidth != 0) {
+    return emitError() << "expected row width to divide row stride";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Initialize attributes
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
index aa012e3..86a3bbf 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
@@ -549,4 +549,73 @@
   }];
 }
 
+
+//===---------------------------------------------------------------------===//
+// iree_codegen.xor_shuffle
+//===---------------------------------------------------------------------===//
+
+def IREECodegen_XORShuffleAttr  :
+    AttrDef<IREECodegen_Dialect, "XORShuffle", [
+    DeclareAttrInterfaceMethods<IREECodegen_SwizzleAttrInterface, [
+        "swizzleOffset",
+        "getAccessElementCount"
+      ]>
+    ]> {
+  let mnemonic = "xor_shuffle";
+  let summary = "An attribute that describes an XOR-based swizzling pattern.";
+  let description = [{
+    Shuffles accesses of |access_width| within rows of size |row_width|.
+    Any given access into a logical memref of shape
+    `memref<...xNx|access_width|x!eltype>`, where `N = row_width / access_width`,
+    at position `(i, j, 0)` is shuffled to `(i, ((i/per_phase) % N) XOR j, 0)`. For example,
+
+    ```
+    row_width = 16, access_width = 4, per_phase = 1
+
+    0000 1111 2222 3333 /// 0 1 2 3
+    4444 5555 6666 7777 /// 0 1 2 3
+    8888 9999 AAAA BBBB /// 0 1 2 3
+    CCCC DDDD EEEE FFFF /// 0 1 2 3
+    ```
+
+    is swizzled to
+    ```
+    0000 1111 2222 3333 /// 0 1 2 3
+    5555 4444 7777 6666 /// 1 0 3 2
+    AAAA BBBB 8888 9999 /// 2 3 0 1
+    FFFF EEEE DDDD CCCC /// 3 2 1 0
+    ```
+    |per_phase| allows keeping the same shuffling across multiple rows. For example,
+
+    ```
+    row_width = 16, access_width = 4, per_phase = 2
+
+    0000 1111 2222 3333 /// 0 1 2 3
+    4444 5555 6666 7777 /// 0 1 2 3
+    8888 9999 AAAA BBBB /// 0 1 2 3
+    CCCC DDDD EEEE FFFF /// 0 1 2 3
+    ```
+
+    is swizzled to
+    ```
+    0000 1111 2222 3333 /// 0 1 2 3
+    4444 5555 6666 7777 /// 0 1 2 3
+    9999 8888 BBBB AAAA /// 1 0 3 2
+    DDDD CCCC FFFF EEEE /// 1 0 3 2
+    ```
+
+    The pattern repeats for subsequent rows.
+  }];
+  let parameters = (ins
+    AttrParameter<"int64_t", "">:$row_width,
+    AttrParameter<"int64_t", "">:$access_width,
+    OptionalParameter<"int64_t", "Row stride. Defaults to row_width.">:$row_stride,
+    OptionalParameter<"int64_t", "Rows per phase. Defaults to 1.">:$per_phase
+  );
+  let assemblyFormat = [{
+    `<` $row_width `,` $access_width  (`,` $row_stride^)? (`,` $per_phase^)? `>`
+  }];
+  let genVerifyDecl = 1;
+}
+
 #endif // IREE_COMPILER_CODEGEN_DIALECT_IREECODEGENATTRS
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index d9e2f49..994beb7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1208,6 +1208,8 @@
 
   if (forROCDL) {
     funcPassManager.addPass(amdgpu::createAmdgpuMaskedloadToLoadPass);
+    // This pass needs to run before the ResolveSwizzleHints pass.
+    funcPassManager.addPass(amdgpu::createAmdgpuFoldMemRefOpsPass);
   }
 
   // This pass needs to run before SCF -> CF.