Add workgroup swizzling for better cache reuse (#8789)
This PR adds a pass that swizzles workgroup ids for better cache reuse.
Currently this is experimental: swizzling is controlled by a global option (iree-codegen-log-swizzle-tile) that takes the base-2 logarithm of the desired swizzle tile size. By default the pass performs no swizzling and does not change the generated code.
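For reference, the id remapping that the pass materializes as arith ops is equivalent to the standalone sketch below (mirroring the getTiledId2 pseudocode in WorkGroupSwizzle.cpp); the function and parameter names here are illustrative only and not part of the patch:

```cpp
#include <utility>

// Illustrative sketch only; tile is 1u << logTile, (x, y) are the original
// workgroup ids, and gridSizeX/gridSizeY are the workgroup counts.
std::pair<unsigned, unsigned> swizzleWorkgroupId(unsigned x, unsigned y,
                                                 unsigned gridSizeX,
                                                 unsigned gridSizeY,
                                                 unsigned tile) {
  // Linearize the id within a horizontal strip of `tile` rows, then re-split
  // it so that each run of `tile` consecutive workgroups forms a column.
  unsigned flat = x + (y % tile) * gridSizeX;
  unsigned tiledX = flat / tile;
  unsigned tiledY = (y / tile) * tile + flat % tile;
  // The last strip is partial when gridSizeY is not a multiple of tile;
  // keep the original ids there to stay in bounds.
  bool partial = gridSizeY % tile != 0 && (y / tile) * tile + tile > gridSizeY;
  return partial ? std::make_pair(x, y) : std::make_pair(tiledX, tiledY);
}
```

With logTile=3 the tile size is 8, which is the value the new lit test below checks against.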
diff --git a/iree/compiler/Codegen/Common/BUILD b/iree/compiler/Codegen/Common/BUILD
index a3359f6..d681159 100644
--- a/iree/compiler/Codegen/Common/BUILD
+++ b/iree/compiler/Codegen/Common/BUILD
@@ -56,6 +56,7 @@
"TypePropagationPass.cpp",
"VectorizeConv.cpp",
"VectorizeMMT4d.cpp",
+ "WorkGroupSwizzle.cpp",
],
hdrs = [
"BufferizationAnalysis.h",
diff --git a/iree/compiler/Codegen/Common/CMakeLists.txt b/iree/compiler/Codegen/Common/CMakeLists.txt
index af26185..c613a6b 100644
--- a/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -49,6 +49,7 @@
"TypePropagationPass.cpp"
"VectorizeConv.cpp"
"VectorizeMMT4d.cpp"
+ "WorkGroupSwizzle.cpp"
DEPS
IREELinalgExtDialect
IREELinalgExtPasses
diff --git a/iree/compiler/Codegen/Common/WorkGroupSwizzle.cpp b/iree/compiler/Codegen/Common/WorkGroupSwizzle.cpp
new file mode 100644
index 0000000..39162c6
--- /dev/null
+++ b/iree/compiler/Codegen/Common/WorkGroupSwizzle.cpp
@@ -0,0 +1,120 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// This function implements the following swizzling logic
+/// void getTiledId2(unsigned x, unsigned y, unsigned* tiledx,
+/// unsigned* tiledy) {
+/// unsigned t_tiledx = (x + (y % tile) * grid_size_x) / tile;
+/// unsigned t_tiledy = (y / tile) * tile +
+/// (x + (y % tile) * grid_size_x) % tile;
+/// bool c = grid_size_y % tile != 0 &&
+/// ((y / tile) * tile + tile) > grid_size_y;
+/// *tiledx = c ? x : t_tiledx;
+/// *tiledy = c ? y : t_tiledy;
+/// }
+// TODO: Make this a callback and move the core functionality of the pass into
+// a utility function.
+static void makeSwizzledId(Location loc, OpBuilder b, Value workgroupIdX,
+ Value workgroupIdY, Value gridSizeX, Value gridSizeY,
+                           Value& swizzledIdX, Value& swizzledIdY,
+ unsigned swizzleTile) {
+ Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+ Value tile = b.create<arith::ConstantIndexOp>(loc, swizzleTile);
+ Value yModTile = b.create<arith::RemUIOp>(loc, workgroupIdY, tile);
+ Value yDivTile = b.create<arith::DivUIOp>(loc, workgroupIdY, tile);
+ Value swizzleParam = b.create<arith::MulIOp>(loc, yModTile, gridSizeX);
+ Value swizzleParam2 =
+ b.create<arith::AddIOp>(loc, workgroupIdX, swizzleParam);
+ Value swizzleParam3 = b.create<arith::RemUIOp>(loc, swizzleParam2, tile);
+ Value swizzleParam4 = b.create<arith::MulIOp>(loc, yDivTile, tile);
+ Value unboundedSwizzledIdX =
+ b.create<arith::DivUIOp>(loc, swizzleParam2, tile);
+ Value unboundedSwizzledIdY =
+ b.create<arith::AddIOp>(loc, swizzleParam3, swizzleParam4);
+ Value gyModTile = b.create<arith::RemUIOp>(loc, gridSizeY, tile);
+ Value gyAddTile = b.create<arith::AddIOp>(loc, swizzleParam4, tile);
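+  // The last strip of tiles along y can be partial when gridSizeY is not a
+  // multiple of the tile size; fall back to the unswizzled ids in that case.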
+ Value condition1 =
+ b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, gyModTile, zero);
+ Value condition2 = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
+ gyAddTile, gridSizeY);
+ Value condition3 = b.create<arith::AndIOp>(loc, condition1, condition2);
+  swizzledIdX = b.create<arith::SelectOp>(loc, condition3, workgroupIdX,
+                                          unboundedSwizzledIdX);
+  swizzledIdY = b.create<arith::SelectOp>(loc, condition3, workgroupIdY,
+                                          unboundedSwizzledIdY);
+}
+namespace {
+struct WorkGroupSwizzlePass
+ : public WorkGroupSwizzleBase<WorkGroupSwizzlePass> {
+ WorkGroupSwizzlePass(unsigned swizzleLogTile)
+ : swizzleLogTile(swizzleLogTile) {}
+
+ void getDependentDialects(DialectRegistry& registry) const override {
+ registry.insert<AffineDialect>();
+ }
+ LogicalResult initializeOptions(StringRef options) override {
+ if (failed(Pass::initializeOptions(options))) {
+ return failure();
+ }
+ swizzleLogTile = logTile;
+ return success();
+ }
+ void runOnOperation() override {
+ if (swizzleLogTile == 0) return;
+    unsigned swizzleTile = 1u << swizzleLogTile;
+    func::FuncOp funcOp = getOperation();
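+    // Find the existing workgroup id ops for dimensions 0 (x) and 1 (y).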
+ std::array<IREE::HAL::InterfaceWorkgroupIDOp, 2> oldWorkgroupIds;
+ bool xFound = false, yFound = false;
+ funcOp.walk([&](IREE::HAL::InterfaceWorkgroupIDOp idOp) {
+ unsigned index = idOp.dimension().getZExtValue();
+ if (index == 0) {
+ oldWorkgroupIds[index] = idOp;
+ xFound = true;
+ } else if (index == 1) {
+ oldWorkgroupIds[index] = idOp;
+ yFound = true;
+ }
+ });
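+    // Swizzling needs both the x and y workgroup ids; bail out otherwise.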
+    if (!xFound || !yFound) return;
+ OpBuilder builder(funcOp);
+ builder.setInsertionPoint(&funcOp.front(), funcOp.front().begin());
+ Value workgroupIdX =
+ builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(funcOp.getLoc(), 0);
+ Value workgroupIdY =
+ builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(funcOp.getLoc(), 1);
+ Value gridSizeX = builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(
+ funcOp.getLoc(), 0);
+ Value gridSizeY = builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(
+ funcOp.getLoc(), 1);
+    Value swizzledIdX, swizzledIdY;
+    makeSwizzledId(funcOp.getLoc(), builder, workgroupIdX, workgroupIdY,
+                   gridSizeX, gridSizeY, swizzledIdX, swizzledIdY, swizzleTile);
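+    // Redirect all uses of the original ids to the swizzled values.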
+    oldWorkgroupIds[0].replaceAllUsesWith(swizzledIdX);
+    oldWorkgroupIds[1].replaceAllUsesWith(swizzledIdY);
+ }
+
+ private:
+ unsigned swizzleLogTile;
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>> createWorkGroupSwizzle(
+ unsigned swizzleLogTile) {
+ return std::make_unique<WorkGroupSwizzlePass>(swizzleLogTile);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Codegen/Common/test/BUILD b/iree/compiler/Codegen/Common/test/BUILD
index f9842b7..d6abde2 100644
--- a/iree/compiler/Codegen/Common/test/BUILD
+++ b/iree/compiler/Codegen/Common/test/BUILD
@@ -34,6 +34,7 @@
"remove_dead_allocs.mlir",
"remove_trivial_loops.mlir",
"rewrite_linalg_destructive_updates.mlir",
+ "swizzle_workgroup.mlir",
"tile_and_distribute_to_workgroups.mlir",
"transpose_canonicalization.mlir",
"type_propagation.mlir",
diff --git a/iree/compiler/Codegen/Common/test/CMakeLists.txt b/iree/compiler/Codegen/Common/test/CMakeLists.txt
index a372c10..f45b163 100644
--- a/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -29,6 +29,7 @@
"remove_dead_allocs.mlir"
"remove_trivial_loops.mlir"
"rewrite_linalg_destructive_updates.mlir"
+ "swizzle_workgroup.mlir"
"tile_and_distribute_to_workgroups.mlir"
"transpose_canonicalization.mlir"
"type_propagation.mlir"
diff --git a/iree/compiler/Codegen/Common/test/swizzle_workgroup.mlir b/iree/compiler/Codegen/Common/test/swizzle_workgroup.mlir
new file mode 100644
index 0000000..73e70bd
--- /dev/null
+++ b/iree/compiler/Codegen/Common/test/swizzle_workgroup.mlir
@@ -0,0 +1,52 @@
+// RUN: iree-opt --iree-workgroup-swizzle='logTile=3' %s | FileCheck %s
+
+func @matmul() {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %c96 = arith.constant 96 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x4096xf32>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:4096x96xf32>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x96xf32>
+ %3 = linalg.init_tensor [128, 96] : tensor<128x96xf32>
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
+ %5 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_y]
+ scf.for %arg0 = %4 to %c128 step %5 {
+ %6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
+ %7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
+ scf.for %arg1 = %6 to %c96 step %7 {
+ %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [32, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x4096xf32> -> tensor<32x4096xf32>
+ %9 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:4096x96xf32> -> tensor<4096x32xf32>
+ %10 = tensor.extract_slice %3[%arg0, %arg1] [32, 32] [1, 1] : tensor<128x96xf32> to tensor<32x32xf32>
+ %11 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 32, 16]]>} ins(%8, %9 : tensor<32x4096xf32>, tensor<4096x32xf32>) outs(%10 : tensor<32x32xf32>) -> tensor<32x32xf32>
+ flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1], sizes = [32, 32], strides = [1, 1] : tensor<32x32xf32> -> !flow.dispatch.tensor<writeonly:128x96xf32>
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func @matmul
+// CHECK: %[[WORKGROUPIDX:.*]] = hal.interface.workgroup.id[0] : index
+// CHECK: %[[WORKGROUPIDY:.*]] = hal.interface.workgroup.id[1] : index
+// CHECK: %[[WORKGROUPCOUNTX:.*]] = hal.interface.workgroup.count[0] : index
+// CHECK: %[[WORKGROUPCOUNTY:.*]] = hal.interface.workgroup.count[1] : index
+// CHECK: %[[CST0:.*]] = arith.constant 0 : index
+// CHECK: %[[CST8:.*]] = arith.constant 8 : index
+// CHECK: %[[S0:.*]] = arith.remui %[[WORKGROUPIDY]], %[[CST8]] : index
+// CHECK: %[[S1:.*]] = arith.divui %[[WORKGROUPIDY]], %[[CST8]] : index
+// CHECK: %[[S2:.*]] = arith.muli %[[S0]], %[[WORKGROUPCOUNTX]] : index
+// CHECK: %[[S3:.*]] = arith.addi %[[WORKGROUPIDX]], %[[S2]] : index
+// CHECK: %[[S4:.*]] = arith.remui %[[S3]], %[[CST8]] : index
+// CHECK: %[[S5:.*]] = arith.muli %[[S1]], %[[CST8]] : index
+// CHECK: %[[S6:.*]] = arith.divui %[[S3]], %[[CST8]] : index
+// CHECK: %[[S7:.*]] = arith.addi %[[S4]], %[[S5]] : index
+// CHECK: %[[S8:.*]] = arith.remui %[[WORKGROUPCOUNTY]], %[[CST8]] : index
+// CHECK: %[[S9:.*]] = arith.addi %[[S5]], %[[CST8]] : index
+// CHECK: %[[S10:.*]] = arith.cmpi ne, %[[S8]], %[[CST0]] : index
+// CHECK: %[[S11:.*]] = arith.cmpi ugt, %[[S9]], %[[WORKGROUPCOUNTY]] : index
+// CHECK: %[[S12:.*]] = arith.andi %[[S10]], %[[S11]] : i1
+// CHECK: %[[S13:.*]] = arith.select %[[S12]], %[[WORKGROUPIDX]], %[[S6]] : index
+// CHECK: %[[S14:.*]] = arith.select %[[S12]], %[[WORKGROUPIDY]], %[[S7]] : index
diff --git a/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index c313f8b..a70bc33 100644
--- a/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -29,6 +29,10 @@
llvm::cl::desc("Pipeline depth"),
llvm::cl::init(4));
+static llvm::cl::opt<unsigned> logSwizzleTile(
+ "iree-codegen-log-swizzle-tile", llvm::cl::desc("log swizzle tile value"),
+ llvm::cl::init(0));
+
static Value gpuAllocationFunction(OpBuilder &builder, Location loc,
ArrayRef<int64_t> staticShape,
Type elementType,
@@ -91,6 +95,9 @@
pm.addNestedPass<func::FuncOp>(
createLLVMGPUReduceSharedMemoryBankConflicts());
pm.addNestedPass<func::FuncOp>(createRemoveSingleIterationLoopPass());
+ pm.addNestedPass<func::FuncOp>(createWorkGroupSwizzle(logSwizzleTile));
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
// Linalg -> vector
pm.addNestedPass<func::FuncOp>(createLLVMGPUVectorizationPass());
@@ -118,6 +125,9 @@
pm.addNestedPass<func::FuncOp>(
createLLVMGPUReduceSharedMemoryBankConflicts());
pm.addNestedPass<func::FuncOp>(createRemoveSingleIterationLoopPass());
+ pm.addNestedPass<func::FuncOp>(createWorkGroupSwizzle(logSwizzleTile));
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
// Linalg -> vector
pm.addNestedPass<func::FuncOp>(createLLVMGPUTensorCoreVectorizationPass());
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index fa25554..a89bbb6 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -153,6 +153,10 @@
/// Creates a pass to convert memref.copy to linalg op.
std::unique_ptr<OperationPass<func::FuncOp>> createMemrefCopyToLinalgPass();
+/// Creates a pass to swizzle workgroup ids for better cache reuse.
+std::unique_ptr<OperationPass<func::FuncOp>> createWorkGroupSwizzle(
+ unsigned swizzleLogTile = 0);
+
//----------------------------------------------------------------------------//
// Common codegen patterns.
//----------------------------------------------------------------------------//
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index eab90a1..d4f23b6 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -162,6 +162,18 @@
let constructor =
"mlir::iree_compiler::createMemrefCopyToLinalgPass()";
}
+
+def WorkGroupSwizzle :
+ Pass<"iree-workgroup-swizzle", "func::FuncOp"> {
+ let summary = "swizzle the workgroup ids for better cache reuse";
+ let constructor = "mlir::iree_compiler::createWorkGroupSwizzle()";
+ let options = [
+ Option<"logTile", "logTile", "unsigned",
+           /*default=*/"0",
+           "Log2 of the swizzle tile size (0 disables swizzling)">,
+ ];
+}
+
//------------------------------------------------------------------------------
// LLVMCPU
//------------------------------------------------------------------------------