blob: bdc86308a4a7deb4d5239b99511e865aea0d4fb6 [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "compiler/plugins/target/CUDA/SetBlockIdsRangePass.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/PassManager.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
// This is a workaround until nvvm-intr-range gets re-enabled by default in the
// NVPTX backend. It also lets us potentially specialize the kernel further by
// setting finer-grained ranges based on the static dispatch size.
using namespace llvm;
#define DEBUG_TYPE "iree-dialect-hal-cuda-llvm-set-block-ids-range"
// Adds the passed-in [Low,High) range information as metadata to the
// passed-in call instruction.
static bool addRangeMetadata(uint64_t Low, uint64_t High, CallInst *C) {
// This call already has range metadata, nothing to do.
if (C->getMetadata(LLVMContext::MD_range))
return false;
LLVMContext &Context = C->getParent()->getContext();
IntegerType *Int32Ty = Type::getInt32Ty(Context);
Metadata *LowAndHigh[] = {
ConstantAsMetadata::get(ConstantInt::get(Int32Ty, Low)),
ConstantAsMetadata::get(ConstantInt::get(Int32Ty, High))};
C->setMetadata(LLVMContext::MD_range, MDNode::get(Context, LowAndHigh));
return true;
}
// Annotates the NVVM thread/block index and size intrinsic calls in `F` with
// `!range` metadata (dimension d is 0 for x, 1 for y, 2 for z):
//
//   tid.{x,y,z}    (index within block): [0, maxWorkgroupSize[d])
//   ntid.{x,y,z}   (block size):         [1, maxWorkgroupSize[d] + 1)
//   ctaid.{x,y,z}  (index within grid):  [0, MaxGridSize[d])
//   nctaid.{x,y,z} (grid size):          [1, MaxGridSize[d] + 1)
//
// Workgroup dimensions whose size is not positive are left unannotated. Calls
// that already carry range metadata are left untouched. Returns true if any
// call was modified.
static bool runOnFunction(Function &F,
                          const std::array<int32_t, 3> &maxWorkgroupSize) {
  // We could use the number of blocks dispatched if it is known at compile
  // time; however, this would prevent kernel re-use. For now just use the API
  // limit.
  static constexpr uint64_t MaxGridSize[3] = {0x7fffffff, 0xffff, 0xffff};

  // Adds [Low, size + Low) for workgroup dimension `Dim`: Low == 0 yields the
  // thread-index range, Low == 1 the block-size range. The size is widened to
  // 64 bits before the addition so INT32_MAX cannot overflow (signed overflow
  // is UB). Non-positive sizes are skipped: 0 would produce an empty (invalid)
  // range and negative values do not fit the i32 metadata.
  auto addWorkgroupRange = [&](CallInst *Call, unsigned Dim, uint64_t Low) {
    int32_t Size = maxWorkgroupSize[Dim];
    if (Size <= 0)
      return false;
    return addRangeMetadata(Low, static_cast<uint64_t>(Size) + Low, Call);
  };
  // Adds [Low, MaxGridSize[Dim] + Low) for grid dimension `Dim`, with the
  // same meaning of `Low` as above.
  auto addGridRange = [](CallInst *Call, unsigned Dim, uint64_t Low) {
    return addRangeMetadata(Low, MaxGridSize[Dim] + Low, Call);
  };

  bool Changed = false;
  for (Instruction &I : instructions(F)) {
    auto *Call = dyn_cast<CallInst>(&I);
    if (!Call)
      continue;
    Function *Callee = Call->getCalledFunction();
    if (!Callee)
      continue;
    switch (Callee->getIntrinsicID()) {
    // Index within block
    case Intrinsic::nvvm_read_ptx_sreg_tid_x:
      Changed |= addWorkgroupRange(Call, 0, 0);
      break;
    case Intrinsic::nvvm_read_ptx_sreg_tid_y:
      Changed |= addWorkgroupRange(Call, 1, 0);
      break;
    case Intrinsic::nvvm_read_ptx_sreg_tid_z:
      Changed |= addWorkgroupRange(Call, 2, 0);
      break;
    // Block size
    case Intrinsic::nvvm_read_ptx_sreg_ntid_x:
      Changed |= addWorkgroupRange(Call, 0, 1);
      break;
    case Intrinsic::nvvm_read_ptx_sreg_ntid_y:
      Changed |= addWorkgroupRange(Call, 1, 1);
      break;
    case Intrinsic::nvvm_read_ptx_sreg_ntid_z:
      Changed |= addWorkgroupRange(Call, 2, 1);
      break;
    // Index within grid
    case Intrinsic::nvvm_read_ptx_sreg_ctaid_x:
      Changed |= addGridRange(Call, 0, 0);
      break;
    case Intrinsic::nvvm_read_ptx_sreg_ctaid_y:
      Changed |= addGridRange(Call, 1, 0);
      break;
    case Intrinsic::nvvm_read_ptx_sreg_ctaid_z:
      Changed |= addGridRange(Call, 2, 0);
      break;
    // Grid size
    case Intrinsic::nvvm_read_ptx_sreg_nctaid_x:
      Changed |= addGridRange(Call, 0, 1);
      break;
    case Intrinsic::nvvm_read_ptx_sreg_nctaid_y:
      Changed |= addGridRange(Call, 1, 1);
      break;
    case Intrinsic::nvvm_read_ptx_sreg_nctaid_z:
      Changed |= addGridRange(Call, 2, 1);
      break;
    default:
      break;
    }
  }
  return Changed;
}
// New-pass-manager entry point. Annotates the NVVM index/size intrinsic calls
// in `F` using the pass's `maxWorkgroupSize` member. When nothing was
// modified every analysis is preserved; otherwise all are invalidated.
PreservedAnalyses SetBlockIdsRangePass::run(Function &F,
                                            FunctionAnalysisManager &AM) {
  const bool Modified = runOnFunction(F, maxWorkgroupSize);
  if (!Modified)
    return PreservedAnalyses::all();
  // Range metadata was attached; conservatively drop all cached analyses.
  return PreservedAnalyses::none();
}