[spirv] Add patterns for folding processor ID uses (#3877)
For processor IDs we have their upper bounds in the workgroup
count function or SPIR-V entry point ABI, which can be used
to fold processor ID uses where possible.

This commit adds a pattern to fold affine.min ops using processor
IDs. It allows us to get a static constant value, which can then
be used for memref subviews.
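For example (mirroring one of the new test cases below), given a
workgroup count function that returns the constants (112, 42, 1),
the IR

```
%id = "gpu.block_id"() {dimension = "x"} : () -> index
%min = affine.min affine_map<()[s0] -> (3, s0 * -2 + 225)>()[%id]
```

folds to `constant 3 : index`: the block ID ranges over [0, 111],
so `s0 * -2 + 225` is always at least 3.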
diff --git a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
index 01e4b3d..5495c87 100644
--- a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
+++ b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
@@ -15,6 +15,7 @@
#ifndef IREE_COMPILER_CONVERSION_CODEGENUTILS_GETNUMWORKGROUPS_H_
#define IREE_COMPILER_CONVERSION_CODEGENUTILS_GETNUMWORKGROUPS_H_
+#include <array>
#include <cstdint>
namespace llvm {
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
index d61e74b..0156b47 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
@@ -37,6 +37,7 @@
"ConvertToGPUPass.cpp",
"ConvertToSPIRVPass.cpp",
"CooperativeMatrixAnalysis.cpp",
+ "FoldGPUProcessorIDUses.cpp",
"KernelDispatchUtils.cpp",
"LinalgTileAndFusePass.cpp",
"MatMulVectorizationTest.cpp",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
index b10d22b..9368a9d 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -39,6 +39,7 @@
"ConvertToGPUPass.cpp"
"ConvertToSPIRVPass.cpp"
"CooperativeMatrixAnalysis.cpp"
+ "FoldGPUProcessorIDUses.cpp"
"KernelDispatchUtils.cpp"
"LinalgTileAndFusePass.cpp"
"MatMulVectorizationTest.cpp"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp b/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp
new file mode 100644
index 0000000..1dda99d
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp
@@ -0,0 +1,266 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- FoldGPUProcessorIDUses.cpp -----------------------------------------===//
+//
+// This file implements patterns and passes for folding GPU processor ID uses.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/TargetAndABI.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-fold-gpu-procid-uses"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+/// Returns true if the given `expr` is a linear expression over one
+/// symbol/dimension.
+///
+/// Note that this function is not meant to check for all linear expression
+/// cases. It only checks that:
+/// 1) There are no mod/div operations, and
+/// 2) For mul operations, one of the operands is a constant.
+/// Also this function assumes `expr` only contains one symbol/dimension.
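+///
+/// For example, `s0 * -2 + 225` passes this check, while `s0 mod 5`,
+/// `s0 ceildiv 5`, and `s0 * s0` do not.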
+bool isLinearExpr(AffineExpr expr) {
+ switch (expr.getKind()) {
+ case mlir::AffineExprKind::Add: {
+ auto binExpr = expr.cast<AffineBinaryOpExpr>();
+ return isLinearExpr(binExpr.getLHS()) && isLinearExpr(binExpr.getRHS());
+ }
+ case mlir::AffineExprKind::Mul: {
+ auto binExpr = expr.cast<AffineBinaryOpExpr>();
+ AffineExpr lhs = binExpr.getLHS();
+ AffineExpr rhs = binExpr.getRHS();
+ return (lhs.isa<AffineConstantExpr>() && isLinearExpr(rhs)) ||
+ (rhs.isa<AffineConstantExpr>() && isLinearExpr(lhs));
+  }
+ case mlir::AffineExprKind::Mod:
+ case mlir::AffineExprKind::FloorDiv:
+ case mlir::AffineExprKind::CeilDiv:
+ return false;
+ case mlir::AffineExprKind::Constant:
+ case mlir::AffineExprKind::DimId:
+ case mlir::AffineExprKind::SymbolId:
+ return true;
+  }
+  llvm_unreachable("unhandled affine expression kind");
+}
+
+/// Replaces the given `symbol` in `expr` with the constant `value`.
+AffineExpr replaceSymbolWithValue(AffineExpr expr, AffineSymbolExpr symbol,
+                                  int64_t value) {
+  auto cstExpr = getAffineConstantExpr(value, expr.getContext());
+  return expr.replace(symbol, cstExpr);
+}
+
+/// Converts a dimension string to its corresponding index. Expects the
+/// dimension to be one of "x", "y", or "z".
+int dimensionToIndex(StringRef dimension) {
+ return StringSwitch<int>(dimension).Case("x", 0).Case("y", 1).Case("z", 2);
+}
+
+/// Gets the block processor ID's upper bound. This queries the workgroup count
+/// function.
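+///
+/// For example, with a workgroup count function returning the constants
+/// (112, 42, 1), the upper bound for the block ID along the y dimension
+/// is 42.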
+Optional<int64_t> getProcessorIDUpperBound(gpu::BlockIdOp blockIDOp) {
+ auto numWorkgroupsFn = getNumWorkgroupsFn(blockIDOp.getParentOfType<FuncOp>(),
+ getNumWorkgroupsFnAttrName());
+ if (!numWorkgroupsFn) return llvm::None;
+
+ Operation *terminator = numWorkgroupsFn.getBlocks().back().getTerminator();
+ auto retOp = dyn_cast<ReturnOp>(terminator);
+ if (!retOp || retOp.getNumOperands() != 3) return llvm::None;
+ LLVM_DEBUG(llvm::dbgs() << "workgroup count function return op: " << retOp
+ << "\n");
+
+ int index = dimensionToIndex(blockIDOp.dimension());
+ IntegerAttr attr;
+ if (!matchPattern(retOp.getOperand(index), m_Constant(&attr)))
+ return llvm::None;
+
+ return attr.getInt();
+}
+
+/// Gets the thread processor ID's upper bound. This queries the SPIR-V entry
+/// point ABI.
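+///
+/// For example, with `spv.entry_point_abi = {local_size = dense<[8, 2, 1]> :
+/// vector<3xi32>}`, the upper bound for the thread ID along the x dimension
+/// is 8.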
+Optional<int64_t> getProcessorIDUpperBound(gpu::ThreadIdOp threadIDOp) {
+ FuncOp funcOp = threadIDOp.getParentOfType<FuncOp>();
+ auto abiAttr = funcOp.getAttrOfType<spirv::EntryPointABIAttr>(
+ spirv::getEntryPointABIAttrName());
+ if (!abiAttr) return llvm::None;
+
+ int index = dimensionToIndex(threadIDOp.dimension());
+ auto valueIt = abiAttr.local_size().getIntValues().begin() + index;
+ return (*valueIt).getZExtValue();
+}
+
+/// Folds `affine.min` ops that have only one symbol operand, where that
+/// operand is a GPU processor ID. For such cases we can use the processor
+/// ID's upper bound to simplify the `affine.min`.
+///
+/// For example, this pattern can simplify the following IR:
+/// ```
+/// %id = "gpu.thread_id"() {dimension = "x"} : () -> index
+/// ...
+/// affine.min affine_map<()[s0] -> (3, s0 * -2 + 225)>()[%id]
+/// ```
+/// if the upper bound for thread ID along the x dimension is 112.
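+/// In that case the op folds to a constant 3: the thread ID ranges over
+/// [0, 111], so `s0 * -2 + 225` ranges over [3, 225] and 3 is always the
+/// minimum.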
+struct FoldAffineMinOverProcessorID : OpRewritePattern<AffineMinOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(AffineMinOp minOp,
+ PatternRewriter &rewriter) const override {
+ LLVM_DEBUG(llvm::dbgs() << "inspecting " << minOp << "\n");
+ MLIRContext *context = minOp.getContext();
+ auto dimensions = minOp.getDimOperands();
+ auto symbols = minOp.getSymbolOperands();
+
+ // We expect the affine.min op to only have one symbol operand.
+ if (!llvm::hasSingleElement(symbols) || !dimensions.empty()) {
+ return rewriter.notifyMatchFailure(
+ minOp, "expected to only have one symbol operand");
+ }
+
+    // And the symbol operand should come from a GPU processor ID.
+    Operation *symbolOp = symbols.front().getDefiningOp();
+    if (!symbolOp) {
+      return rewriter.notifyMatchFailure(
+          minOp, "symbol operand is not defined by an op");
+    }
+ auto symbol0 = getAffineSymbolExpr(0, context).cast<AffineSymbolExpr>();
+
+ Optional<int64_t> ub;
+ if (auto blockIDOp = dyn_cast<gpu::BlockIdOp>(symbolOp)) {
+ ub = getProcessorIDUpperBound(blockIDOp);
+ } else if (auto threadIDOp = dyn_cast<gpu::ThreadIdOp>(symbolOp)) {
+ ub = getProcessorIDUpperBound(threadIDOp);
+ }
+ if (!ub) {
+ return rewriter.notifyMatchFailure(
+ minOp, "failed to query processor ID upper bound");
+ }
+ LLVM_DEBUG(llvm::dbgs() << "processor ID '" << *symbolOp
+ << "' upper bound: " << *ub << "\n");
+
+    // Look at each result expression. For expressions that are functions of
+    // the input symbol, try to simplify them. We do this by replacing the
+    // symbol with its lower bound (0) and its upper bound (*ub - 1). This
+    // requires the result expression to be a linear function of the input
+    // symbol.
+ SmallVector<AffineExpr, 4> results;
+ // The indices into `results` where the corresponding AffineExpr is a
+ // constant from the original map. We need to keep track of this so later we
+ // can probe whether the constant is the min.
+ SmallVector<unsigned, 4> cstIndices;
+ for (auto result : minOp.getAffineMap().getResults()) {
+ if (auto cstResult = result.dyn_cast<AffineConstantExpr>()) {
+ results.push_back(cstResult);
+ cstIndices.push_back(results.size() - 1);
+ } else if (isLinearExpr(result)) {
+ results.push_back(simplifyAffineExpr(
+ replaceSymbolWithValue(result, symbol0, 0), 0, 1));
+ results.push_back(simplifyAffineExpr(
+ replaceSymbolWithValue(result, symbol0, *ub - 1), 0, 1));
+ LLVM_DEBUG({
+ auto map = AffineMap::get(0, 1, results, context);
+ llvm::dbgs() << "map after substituting with processor ID bounds: "
+ << map << "\n";
+ });
+ } else {
+ // We cannot handle such cases. Just bail out on matching the pattern.
+ return rewriter.notifyMatchFailure(
+ minOp, "expected to have a linear function of the symbol");
+ }
+ }
+
+ // Returns true if the given affine expression is a non-negative constant.
+ auto isNonNegativeCstExpr = [](AffineExpr e) {
+ if (auto cst = e.dyn_cast<AffineConstantExpr>())
+ return cst.getValue() >= 0;
+ return false;
+ };
+
+ // Check whether any of the original constant expressions, when subtracted
+ // from all other expressions, produces only >= 0 constants. If so, it is
+ // the min.
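+    // For example, for the map in the pattern documentation above, `results`
+    // is [3, 225, 3] (the constant, then the substitutions s0 = 0 and
+    // s0 = 111); subtracting the candidate 3 yields [0, 222, 0], all
+    // non-negative, so 3 is the minimum.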
+ for (auto cstIndex : cstIndices) {
+ auto candidate = results[cstIndex].cast<AffineConstantExpr>();
+
+ SmallVector<AffineExpr, 4> subExprs;
+ subExprs.reserve(results.size());
+ for (auto r : results) subExprs.push_back(r - candidate);
+
+ AffineMap subMap =
+ simplifyAffineMap(AffineMap::get(0, 1, subExprs, context));
+ LLVM_DEBUG(llvm::dbgs() << "map by subtracting expr '" << candidate
+ << "': " << subMap << "\n");
+ if (llvm::all_of(subMap.getResults(), isNonNegativeCstExpr)) {
+ rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp,
+ candidate.getValue());
+ return success();
+ }
+ }
+
+ return failure();
+ }
+};
+
+/// A pass that folds GPU processor ID uses where possible.
+struct FoldGPUProcessorIDUsesPass
+    : public PassWrapper<FoldGPUProcessorIDUsesPass, FunctionPass> {
+  FoldGPUProcessorIDUsesPass() = default;
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<AffineDialect, gpu::GPUDialect>();
+ }
+
+ void runOnFunction() override {
+ MLIRContext *context = &getContext();
+ OwningRewritePatternList patterns;
+ populateFoldGPUProcessorIDUsesPatterns(context, patterns);
+ applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+ }
+};
+
+}  // namespace
+
+void populateFoldGPUProcessorIDUsesPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns) {
+ patterns.insert<FoldAffineMinOverProcessorID>(context);
+ AffineMinOp::getCanonicalizationPatterns(patterns, context);
+}
+
+std::unique_ptr<OperationPass<FuncOp>> createFoldProcessorIDUsesPass() {
+ return std::make_unique<FoldGPUProcessIDUsesPass>();
+}
+
+static PassRegistration<FoldGPUProcessIDUsesPass> pass(
+ "iree-codegen-fold-gpu-procid-uses",
+ "Fold GPU processor ID uses where possible",
+ [] { return std::make_unique<FoldGPUProcessIDUsesPass>(); });
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
index 7c650ca..482d4ae 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
@@ -24,6 +24,10 @@
namespace mlir {
namespace iree_compiler {
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
+
/// Pass to tile and fuse linalg operations on buffers. The pass takes as
/// argument the `workgroupSize` that the tiling should use. Note that the
/// tile-sizes are the reverse of the workgroup size. So workgroup size along
@@ -59,11 +63,18 @@
/// Pass to apply tiling and vectorization transformations on linalg::MatMulOp.
std::unique_ptr<FunctionPass> createMatMulTileAndVectorizeGPUPass();
-/// Convert memref of scalar to memref of vector of efficent size. This will
+/// Converts memref of scalar to memref of vector of efficient size. This will
/// allow to convert memory accesses to vector load/store in SPIR-V without
/// having pointer bitcast.
std::unique_ptr<OperationPass<ModuleOp>> createVectorizeMemref();
+/// Creates a pass to fold processor ID uses where possible.
+std::unique_ptr<OperationPass<FuncOp>> createFoldProcessorIDUsesPass();
+
+//===----------------------------------------------------------------------===//
+// Pipelines
+//===----------------------------------------------------------------------===//
+
/// Populates passes needed to lower an XLA HLO op to SPIR-V dialect via the
/// structured ops path. The pass manager `pm` here operates on the module
/// within the IREE::HAL::ExecutableOp. The `workGroupSize` can be used to
@@ -73,9 +84,19 @@
void buildSPIRVTransformPassPipeline(OpPassManager &pm,
const SPIRVCodegenOptions &options);
-/// Populate patterns to tile and distribute linalg operations.
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
+/// Populates patterns to tile and distribute linalg operations.
void populateLinalgTileAndDistributePatterns(
MLIRContext *context, OwningRewritePatternList &patterns);
+
+/// Populates patterns to fold processor ID uses by using processor count
+/// information where possible.
+void populateFoldGPUProcessorIDUsesPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns);
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir
new file mode 100644
index 0000000..1d45673
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir
@@ -0,0 +1,101 @@
+// RUN: iree-opt -split-input-file -iree-codegen-fold-gpu-procid-uses %s | IreeFileCheck %s
+
+module {
+ // CHECK-LABEL: func @fold_block_id_x()
+ func @fold_block_id_x() -> index attributes {hal.num_workgroups_fn = @num_workgroups} {
+ // CHECK: %[[cst:.+]] = constant 3
+ // CHECK: return %[[cst]]
+ %0 = "gpu.block_id"() {dimension = "x"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (3, s0 * -2 + 225)>()[%0]
+ return %1: index
+ }
+
+ // CHECK-LABEL: func @fold_block_id_y()
+ func @fold_block_id_y() -> index attributes {hal.num_workgroups_fn = @num_workgroups} {
+ // CHECK: %[[cst:.+]] = constant 8
+ // CHECK: return %[[cst]]
+ %0 = "gpu.block_id"() {dimension = "y"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (8, s0 * -1 + s0 * -1 + s0 * -1 + 131)>()[%0]
+ return %1: index
+ }
+
+ // CHECK-LABEL: func @fold_block_id_z()
+ func @fold_block_id_z() -> index attributes {hal.num_workgroups_fn = @num_workgroups} {
+ // CHECK: %[[cst:.+]] = constant 11
+ // CHECK: return %[[cst]]
+ %0 = "gpu.block_id"() {dimension = "z"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (11, s0 + 15)>()[%0]
+ return %1: index
+ }
+
+ func @num_workgroups() -> (index, index, index) {
+ %x = constant 112: index
+ %y = constant 42: index
+ %z = constant 1: index
+ return %x, %y, %z: index, index, index
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_thread_id_x()
+func @fold_thread_id_x() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ // CHECK: %[[cst:.+]] = constant 7
+ // CHECK: return %[[cst]]
+ %0 = "gpu.thread_id"() {dimension = "x"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (7, s0 * -1 + s0 * -1 + 21)>()[%0]
+ return %1: index
+}
+
+// CHECK-LABEL: func @fold_thread_id_y()
+func @fold_thread_id_y() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ // CHECK: %[[cst:.+]] = constant 11
+ // CHECK: return %[[cst]]
+ %0 = "gpu.thread_id"() {dimension = "y"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (11, s0 * -3 + 14)>()[%0]
+ return %1: index
+}
+
+// CHECK-LABEL: func @fold_thread_id_z()
+func @fold_thread_id_z() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ // CHECK: %[[cst:.+]] = constant 21
+ // CHECK: return %[[cst]]
+ %0 = "gpu.thread_id"() {dimension = "z"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (21, s0 + (s0 + 21))>()[%0]
+ return %1: index
+}
+
+// -----
+
+// CHECK-LABEL: func @does_not_fold_mod()
+func @does_not_fold_mod() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ // CHECK: affine.min
+ %0 = "gpu.thread_id"() {dimension = "z"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (21, s0 mod 5)>()[%0]
+ return %1: index
+}
+
+// CHECK-LABEL: func @does_not_fold_div()
+func @does_not_fold_div() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ // CHECK: affine.min
+ %0 = "gpu.thread_id"() {dimension = "z"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (21, s0 ceildiv 5)>()[%0]
+ return %1: index
+}
+
+// CHECK-LABEL: func @does_not_fold_symbol_mul_symbol()
+func @does_not_fold_symbol_mul_symbol() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ // CHECK: affine.min
+ %0 = "gpu.thread_id"() {dimension = "z"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (21, s0 * s0)>()[%0]
+ return %1: index
+}
+
+// CHECK-LABEL: func @does_not_fold_if_cst_not_lower_bound()
+func @does_not_fold_if_cst_not_lower_bound() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ // CHECK: affine.min
+ %0 = "gpu.thread_id"() {dimension = "x"} : () -> index
+ // 5 is in %0's range of [0,7] so we cannot fold the following into 5 or 0.
+ %1 = affine.min affine_map<()[s0] -> (5, s0)>()[%0]
+ return %1: index
+}
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 3b8e8bb..6d72146 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -47,6 +47,7 @@
static bool init_once = []() {
// LinalgToSPIRV
createConvertToGPUPass();
+ createFoldProcessorIDUsesPass();
createLinalgTileAndFusePass(SPIRVCodegenOptions());
createSplitDispatchFunctionPass();
createVectorToGPUPass();