Add workgroup swizzling for better cache reuse (#8789)

This PR adds a pass that swizzles workgroup ids for better cache reuse.
The pass is currently experimental and is controlled by a global option that specifies the log2 of the desired swizzle tile size. By default the pass does not do any swizzling and will not change the generated code.
diff --git a/iree/compiler/Codegen/Common/BUILD b/iree/compiler/Codegen/Common/BUILD
index a3359f6..d681159 100644
--- a/iree/compiler/Codegen/Common/BUILD
+++ b/iree/compiler/Codegen/Common/BUILD
@@ -56,6 +56,7 @@
         "TypePropagationPass.cpp",
         "VectorizeConv.cpp",
         "VectorizeMMT4d.cpp",
+        "WorkGroupSwizzle.cpp",
     ],
     hdrs = [
         "BufferizationAnalysis.h",
diff --git a/iree/compiler/Codegen/Common/CMakeLists.txt b/iree/compiler/Codegen/Common/CMakeLists.txt
index af26185..c613a6b 100644
--- a/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -49,6 +49,7 @@
     "TypePropagationPass.cpp"
     "VectorizeConv.cpp"
     "VectorizeMMT4d.cpp"
+    "WorkGroupSwizzle.cpp"
   DEPS
     IREELinalgExtDialect
     IREELinalgExtPasses
diff --git a/iree/compiler/Codegen/Common/WorkGroupSwizzle.cpp b/iree/compiler/Codegen/Common/WorkGroupSwizzle.cpp
new file mode 100644
index 0000000..39162c6
--- /dev/null
+++ b/iree/compiler/Codegen/Common/WorkGroupSwizzle.cpp
@@ -0,0 +1,158 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <limits>
+
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Emits, at the insertion point of `b`, the ops that map the workgroup id
+/// (`workgroupIdX`, `workgroupIdY`) to a swizzled id. Writing x, y for the two
+/// ids and `tile` for `swizzleTile`, the emitted ops compute (every div, rem
+/// and compare is unsigned):
+///
+///   t    = x + (y % tile) * gridSizeX
+///   keep = gridSizeY % tile != 0 && (y / tile) * tile + tile > gridSizeY
+///   swizzledIdX = keep ? x : t / tile
+///   swizzledIdY = keep ? y : (y / tile) * tile + t % tile
+///
+/// `keep` leaves the ids untouched for rows in a trailing band of `tile` rows
+/// that extends past `gridSizeY`. The results are returned through the
+/// `swizzledIdX` / `swizzledIdY` out-parameters.
+///
+/// NOTE: swizzle_workgroup.mlir checks the emitted ops in creation order, so
+/// the order of the `create` calls below must stay stable.
+// TODO: Make this a callback and the core functionality in the pass a utility
+// function.
+static void makeSwizzledId(Location loc, OpBuilder& b, Value workgroupIdX,
+                           Value workgroupIdY, Value gridSizeX, Value gridSizeY,
+                           Value& swizzledIdX, Value& swizzledIdY,
+                           unsigned swizzleTile) {
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value tile = b.create<arith::ConstantIndexOp>(loc, swizzleTile);
+  Value yModTile = b.create<arith::RemUIOp>(loc, workgroupIdY, tile);
+  Value yDivTile = b.create<arith::DivUIOp>(loc, workgroupIdY, tile);
+  // t = x + (y % tile) * gridSizeX
+  Value rowOffset = b.create<arith::MulIOp>(loc, yModTile, gridSizeX);
+  Value linearId = b.create<arith::AddIOp>(loc, workgroupIdX, rowOffset);
+  Value linearModTile = b.create<arith::RemUIOp>(loc, linearId, tile);
+  // First row of the band of `tile` rows that contains y.
+  Value bandStart = b.create<arith::MulIOp>(loc, yDivTile, tile);
+  Value unboundedSwizzledIdX = b.create<arith::DivUIOp>(loc, linearId, tile);
+  Value unboundedSwizzledIdY =
+      b.create<arith::AddIOp>(loc, linearModTile, bandStart);
+  // keep = gridSizeY % tile != 0 && bandStart + tile > gridSizeY
+  Value gyModTile = b.create<arith::RemUIOp>(loc, gridSizeY, tile);
+  Value bandEnd = b.create<arith::AddIOp>(loc, bandStart, tile);
+  Value hasPartialBand =
+      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, gyModTile, zero);
+  Value bandOverflows = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
+                                                bandEnd, gridSizeY);
+  Value keepOriginal =
+      b.create<arith::AndIOp>(loc, hasPartialBand, bandOverflows);
+  swizzledIdX = b.create<arith::SelectOp>(loc, keepOriginal, workgroupIdX,
+                                          unboundedSwizzledIdX);
+  swizzledIdY = b.create<arith::SelectOp>(loc, keepOriginal, workgroupIdY,
+                                          unboundedSwizzledIdY);
+}
+
+namespace {
+/// Pass that redirects every use of `hal.interface.workgroup.id[0]` and
+/// `hal.interface.workgroup.id[1]` in a function to the swizzled ids built by
+/// `makeSwizzledId`, using a tile of `2^swizzleLogTile`. It does nothing when
+/// `swizzleLogTile` is 0 or when either id dimension is not used.
+struct WorkGroupSwizzlePass
+    : public WorkGroupSwizzleBase<WorkGroupSwizzlePass> {
+  explicit WorkGroupSwizzlePass(unsigned swizzleLogTile)
+      : swizzleLogTile(swizzleLogTile) {}
+
+  void getDependentDialects(DialectRegistry& registry) const override {
+    // NOTE(review): the ops created here are arith and HAL ops, yet only the
+    // affine dialect is registered; confirm those dialects are always loaded.
+    registry.insert<AffineDialect>();
+  }
+
+  /// Parses the pass options; the `logTile` option then replaces the value
+  /// given to the constructor.
+  LogicalResult initializeOptions(StringRef options) override {
+    if (failed(Pass::initializeOptions(options))) {
+      return failure();
+    }
+    swizzleLogTile = logTile;
+    return success();
+  }
+
+  void runOnOperation() override {
+    if (swizzleLogTile == 0) return;
+    FuncOp funcOp = getOperation();
+
+    // The tile is `1 << swizzleLogTile`; shifting by the full width of
+    // `unsigned` (or more) is undefined, so reject such values up front.
+    constexpr unsigned kUnsignedBits = std::numeric_limits<unsigned>::digits;
+    if (swizzleLogTile >= kUnsignedBits) {
+      funcOp.emitError("log of swizzle tile size is ")
+          << swizzleLogTile << " but must be less than " << kUnsignedBits;
+      signalPassFailure();
+      return;
+    }
+    const unsigned swizzleTile = 1u << swizzleLogTile;
+
+    // Gather every x and y workgroup id op. A function may contain several
+    // ops for the same dimension; rewriting only one of them would mix
+    // swizzled and unswizzled ids.
+    SmallVector<IREE::HAL::InterfaceWorkgroupIDOp> xIdOps, yIdOps;
+    funcOp.walk([&](IREE::HAL::InterfaceWorkgroupIDOp idOp) {
+      unsigned index = idOp.dimension().getZExtValue();
+      if (index == 0) {
+        xIdOps.push_back(idOp);
+      } else if (index == 1) {
+        yIdOps.push_back(idOp);
+      }
+    });
+    if (xIdOps.empty() || yIdOps.empty()) return;
+
+    // Emit fresh id/count ops and the swizzle arithmetic at the very start of
+    // the function so that they dominate every existing use.
+    OpBuilder builder(funcOp);
+    builder.setInsertionPoint(&funcOp.front(), funcOp.front().begin());
+    Location loc = funcOp.getLoc();
+    Value workgroupIdX =
+        builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(loc, 0);
+    Value workgroupIdY =
+        builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(loc, 1);
+    Value gridSizeX =
+        builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 0);
+    Value gridSizeY =
+        builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 1);
+    Value swizzledIdX, swizzledIdY;
+    makeSwizzledId(loc, builder, workgroupIdX, workgroupIdY, gridSizeX,
+                   gridSizeY, swizzledIdX, swizzledIdY, swizzleTile);
+
+    // The ops created above were not visited by the walk, so their uses in
+    // the swizzle arithmetic are left intact.
+    for (IREE::HAL::InterfaceWorkgroupIDOp idOp : xIdOps) {
+      idOp.replaceAllUsesWith(swizzledIdX);
+    }
+    for (IREE::HAL::InterfaceWorkgroupIDOp idOp : yIdOps) {
+      idOp.replaceAllUsesWith(swizzledIdY);
+    }
+  }
+
+ private:
+  unsigned swizzleLogTile;
+};
+}  // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>> createWorkGroupSwizzle(
+    unsigned swizzleLogTile) {
+  return std::make_unique<WorkGroupSwizzlePass>(swizzleLogTile);
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Codegen/Common/test/BUILD b/iree/compiler/Codegen/Common/test/BUILD
index f9842b7..d6abde2 100644
--- a/iree/compiler/Codegen/Common/test/BUILD
+++ b/iree/compiler/Codegen/Common/test/BUILD
@@ -34,6 +34,7 @@
             "remove_dead_allocs.mlir",
             "remove_trivial_loops.mlir",
             "rewrite_linalg_destructive_updates.mlir",
+            "swizzle_workgroup.mlir",
             "tile_and_distribute_to_workgroups.mlir",
             "transpose_canonicalization.mlir",
             "type_propagation.mlir",
diff --git a/iree/compiler/Codegen/Common/test/CMakeLists.txt b/iree/compiler/Codegen/Common/test/CMakeLists.txt
index a372c10..f45b163 100644
--- a/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -29,6 +29,7 @@
     "remove_dead_allocs.mlir"
     "remove_trivial_loops.mlir"
     "rewrite_linalg_destructive_updates.mlir"
+    "swizzle_workgroup.mlir"
     "tile_and_distribute_to_workgroups.mlir"
     "transpose_canonicalization.mlir"
     "type_propagation.mlir"
diff --git a/iree/compiler/Codegen/Common/test/swizzle_workgroup.mlir b/iree/compiler/Codegen/Common/test/swizzle_workgroup.mlir
new file mode 100644
index 0000000..73e70bd
--- /dev/null
+++ b/iree/compiler/Codegen/Common/test/swizzle_workgroup.mlir
@@ -0,0 +1,55 @@
+// RUN: iree-opt --iree-workgroup-swizzle='logTile=3' %s | FileCheck %s
+
+func @matmul() {
+  %c0 = arith.constant 0 : index
+  %c128 = arith.constant 128 : index
+  %c96 = arith.constant 96 : index
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x4096xf32>
+  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:4096x96xf32>
+  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x96xf32>
+  %3 = linalg.init_tensor [128, 96] : tensor<128x96xf32>
+  %workgroup_id_x = hal.interface.workgroup.id[0] : index
+  %workgroup_count_x = hal.interface.workgroup.count[0] : index
+  %workgroup_id_y = hal.interface.workgroup.id[1] : index
+  %workgroup_count_y = hal.interface.workgroup.count[1] : index
+  %4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
+  %5 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_y]
+  scf.for %arg0 = %4 to %c128 step %5 {
+    %6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
+    %7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
+    scf.for %arg1 = %6 to %c96 step %7 {
+      %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [32, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x4096xf32> -> tensor<32x4096xf32>
+      %9 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:4096x96xf32> -> tensor<4096x32xf32>
+      %10 = tensor.extract_slice %3[%arg0, %arg1] [32, 32] [1, 1] : tensor<128x96xf32> to tensor<32x32xf32>
+      %11 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 32, 16]]>} ins(%8, %9 : tensor<32x4096xf32>, tensor<4096x32xf32>) outs(%10 : tensor<32x32xf32>) -> tensor<32x32xf32>
+      flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1], sizes = [32, 32], strides = [1, 1] : tensor<32x32xf32> -> !flow.dispatch.tensor<writeonly:128x96xf32>
+    }
+  }
+  return
+}
+
+//    CHECK-LABEL: func @matmul
+//          CHECK: %[[WORKGROUPIDX:.*]] = hal.interface.workgroup.id[0] : index
+//          CHECK: %[[WORKGROUPIDY:.*]] = hal.interface.workgroup.id[1] : index
+//          CHECK: %[[WORKGROUPCOUNTX:.*]] = hal.interface.workgroup.count[0] : index
+//          CHECK: %[[WORKGROUPCOUNTY:.*]] = hal.interface.workgroup.count[1] : index
+//          CHECK: %[[CST0:.*]] = arith.constant 0 : index
+//          CHECK: %[[CST8:.*]] = arith.constant 8 : index
+//          CHECK: %[[S0:.*]] = arith.remui %[[WORKGROUPIDY]], %[[CST8]] : index
+//          CHECK: %[[S1:.*]] = arith.divui %[[WORKGROUPIDY]], %[[CST8]] : index
+//          CHECK: %[[S2:.*]] = arith.muli %[[S0]], %[[WORKGROUPCOUNTX]] : index
+//          CHECK: %[[S3:.*]] = arith.addi %[[WORKGROUPIDX]], %[[S2]] : index
+//          CHECK: %[[S4:.*]] = arith.remui %[[S3]], %[[CST8]] : index
+//          CHECK: %[[S5:.*]] = arith.muli %[[S1]], %[[CST8]] : index
+//          CHECK: %[[S6:.*]] = arith.divui %[[S3]], %[[CST8]] : index
+//          CHECK: %[[S7:.*]] = arith.addi %[[S4]], %[[S5]] : index
+//          CHECK: %[[S8:.*]] = arith.remui %[[WORKGROUPCOUNTY]], %[[CST8]] : index
+//          CHECK: %[[S9:.*]] = arith.addi %[[S5]], %[[CST8]] : index
+//          CHECK: %[[S10:.*]] = arith.cmpi ne, %[[S8]], %[[CST0]] : index
+//          CHECK: %[[S11:.*]] = arith.cmpi ugt, %[[S9]], %[[WORKGROUPCOUNTY]] : index
+//          CHECK: %[[S12:.*]] = arith.andi %[[S10]], %[[S11]] : i1
+//          CHECK: %[[S13:.*]] = arith.select %[[S12]], %[[WORKGROUPIDX]], %[[S6]] : index
+//          CHECK: %[[S14:.*]] = arith.select %[[S12]], %[[WORKGROUPIDY]], %[[S7]] : index
+
+
+
diff --git a/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index c313f8b..a70bc33 100644
--- a/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -29,6 +29,10 @@
                                              llvm::cl::desc("Pipeline depth"),
                                              llvm::cl::init(4));
 
+static llvm::cl::opt<unsigned> logSwizzleTile(
+    "iree-codegen-log-swizzle-tile", llvm::cl::desc("log swizzle tile value"),
+    llvm::cl::init(0));  // Default 0: the swizzle pass is a no-op.
+
 static Value gpuAllocationFunction(OpBuilder &builder, Location loc,
                                    ArrayRef<int64_t> staticShape,
                                    Type elementType,
@@ -91,6 +95,9 @@
   pm.addNestedPass<func::FuncOp>(
       createLLVMGPUReduceSharedMemoryBankConflicts());
   pm.addNestedPass<func::FuncOp>(createRemoveSingleIterationLoopPass());
+  pm.addNestedPass<func::FuncOp>(createWorkGroupSwizzle(logSwizzleTile));
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
 
   // Linalg -> vector
   pm.addNestedPass<func::FuncOp>(createLLVMGPUVectorizationPass());
@@ -118,6 +125,9 @@
   pm.addNestedPass<func::FuncOp>(
       createLLVMGPUReduceSharedMemoryBankConflicts());
   pm.addNestedPass<func::FuncOp>(createRemoveSingleIterationLoopPass());
+  pm.addNestedPass<func::FuncOp>(createWorkGroupSwizzle(logSwizzleTile));
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
 
   // Linalg -> vector
   pm.addNestedPass<func::FuncOp>(createLLVMGPUTensorCoreVectorizationPass());
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index fa25554..a89bbb6 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -153,6 +153,10 @@
 /// Creates a pass to convert memref.copy to linalg op.
 std::unique_ptr<OperationPass<func::FuncOp>> createMemrefCopyToLinalgPass();
 
+/// Creates a pass that swizzles workgroup ids (tile size 2^swizzleLogTile).
+std::unique_ptr<OperationPass<func::FuncOp>> createWorkGroupSwizzle(
+    unsigned swizzleLogTile = 0);
+
 //----------------------------------------------------------------------------//
 // Common codegen patterns.
 //----------------------------------------------------------------------------//
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index eab90a1..d4f23b6 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -162,6 +162,18 @@
   let constructor =
       "mlir::iree_compiler::createMemrefCopyToLinalgPass()";
 }
+
+def WorkGroupSwizzle :
+    Pass<"iree-workgroup-swizzle", "func::FuncOp"> {
+  let summary = "swizzle the workgroup ids for better cache reuse";
+  let constructor = "mlir::iree_compiler::createWorkGroupSwizzle()";
+  let options = [
+    Option<"logTile", "logTile", "unsigned",
+           /*default=*/"0",
+           "log2 of the swizzle tile size (0 disables swizzling)">,
+  ];
+}
+
 //------------------------------------------------------------------------------
 // LLVMCPU
 //------------------------------------------------------------------------------